Pass prompt as system_instruction for Gemini 1.5 models (#120147)

This commit is contained in:
tronikos 2024-06-22 03:35:48 -07:00 committed by GitHub
parent 57eb8dab6a
commit ad1f0db5a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 253 additions and 141 deletions

View File

@ -161,10 +161,14 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
intent_response = intent.IntentResponse(language=user_input.language) result = conversation.ConversationResult(
llm_api: llm.APIInstance | None = None response=intent.IntentResponse(language=user_input.language),
tools: list[dict[str, Any]] | None = None conversation_id=user_input.conversation_id
user_name: str | None = None if user_input.conversation_id in self.history
else ulid.ulid_now(),
)
assert result.conversation_id
llm_context = llm.LLMContext( llm_context = llm.LLMContext(
platform=DOMAIN, platform=DOMAIN,
context=user_input.context, context=user_input.context,
@ -173,7 +177,8 @@ class GoogleGenerativeAIConversationEntity(
assistant=conversation.DOMAIN, assistant=conversation.DOMAIN,
device_id=user_input.device_id, device_id=user_input.device_id,
) )
llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API): if self.entry.options.get(CONF_LLM_HASS_API):
try: try:
llm_api = await llm.async_get_api( llm_api = await llm.async_get_api(
@ -183,17 +188,33 @@ class GoogleGenerativeAIConversationEntity(
) )
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error( result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}", f"Error preparing LLM API: {err}",
) )
return conversation.ConversationResult( return result
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.tools] tools = [_format_tool(tool) for tool in llm_api.tools]
try:
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return result
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
# Assume future versions will support it (if not, the request fails with a
# clear message, at which point we can fix it).
supports_system_instruction = (
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
)
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model_name=model_name,
generation_config={ generation_config={
"temperature": self.entry.options.get( "temperature": self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
@ -219,69 +240,25 @@ class GoogleGenerativeAIConversationEntity(
), ),
}, },
tools=tools or None, tools=tools or None,
system_instruction=prompt if supports_system_instruction else None,
) )
if user_input.conversation_id in self.history: messages = self.history.get(result.conversation_id, [])
conversation_id = user_input.conversation_id if not supports_system_instruction:
messages = self.history[conversation_id] if not messages:
else:
conversation_id = ulid.ulid_now()
messages = [{}, {"role": "model", "parts": "Ok"}] messages = [{}, {"role": "model", "parts": "Ok"}]
messages[0] = {"role": "user", "parts": prompt}
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
try:
if llm_api:
api_prompt = llm_api.api_prompt
else:
api_prompt = llm.async_render_no_api_prompt(self.hass)
prompt = "\n".join(
(
template.Template(
llm.BASE_PROMPT
+ self.entry.options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
),
api_prompt,
)
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
# Make a copy, because we attach it to the trace event.
messages = [
{"role": "user", "parts": prompt},
*messages[1:],
]
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append( trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} trace.ConversationTraceEventType.AGENT_DETAIL,
{
# Make a copy to attach it to the trace event.
"messages": messages[:]
if supports_system_instruction
else messages[2:],
"prompt": prompt,
},
) )
chat = model.start_chat(history=messages) chat = model.start_chat(history=messages)
@ -307,24 +284,20 @@ class GoogleGenerativeAIConversationEntity(
f"Sorry, I had a problem talking to Google Generative AI: {err}" f"Sorry, I had a problem talking to Google Generative AI: {err}"
) )
intent_response.async_set_error( result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
error, error,
) )
return conversation.ConversationResult( return result
response=intent_response, conversation_id=conversation_id
)
LOGGER.debug("Response: %s", chat_response.parts) LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts: if not chat_response.parts:
intent_response.async_set_error( result.response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem getting a response from Google Generative AI.", "Sorry, I had a problem getting a response from Google Generative AI.",
) )
return conversation.ConversationResult( return result
response=intent_response, conversation_id=conversation_id self.history[result.conversation_id] = chat.history
)
self.history[conversation_id] = chat.history
function_calls = [ function_calls = [
part.function_call for part in chat_response.parts if part.function_call part.function_call for part in chat_response.parts if part.function_call
] ]
@ -355,9 +328,48 @@ class GoogleGenerativeAIConversationEntity(
) )
chat_request = protos.Content(parts=tool_responses) chat_request = protos.Content(parts=tool_responses)
intent_response.async_set_speech( result.response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text]) " ".join([part.text.strip() for part in chat_response.parts if part.text])
) )
return conversation.ConversationResult( return result
response=intent_response, conversation_id=conversation_id
async def _async_render_prompt(
self,
user_input: conversation.ConversationInput,
llm_api: llm.APIInstance | None,
llm_context: llm.LLMContext,
) -> str:
user_name: str | None = None
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
if llm_api:
api_prompt = llm_api.api_prompt
else:
api_prompt = llm.async_render_no_api_prompt(self.hass)
return "\n".join(
(
template.Template(
llm.BASE_PROMPT
+ self.entry.options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
),
api_prompt,
)
) )

View File

@ -43,6 +43,7 @@ BASE_PROMPT = (
) )
DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant. DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
""" """

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_history # name: test_chat_history[models/gemini-1.0-pro-False]
list([ list([
tuple( tuple(
'', '',
@ -12,13 +12,14 @@
'top_k': 64, 'top_k': 64,
'top_p': 0.95, 'top_p': 0.95,
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.0-pro',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': None,
'tools': None, 'tools': None,
}), }),
), ),
@ -32,6 +33,7 @@
'parts': ''' 'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24. Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant. You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''', ''',
@ -63,13 +65,14 @@
'top_k': 64, 'top_k': 64,
'top_p': 0.95, 'top_p': 0.95,
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.0-pro',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': None,
'tools': None, 'tools': None,
}), }),
), ),
@ -83,6 +86,7 @@
'parts': ''' 'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24. Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant. You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''', ''',
@ -113,6 +117,108 @@
), ),
]) ])
# --- # ---
# name: test_chat_history[models/gemini-1.5-pro-True]
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-pro',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'1st user request',
),
dict({
}),
),
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-pro',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '1st user request',
'role': 'user',
}),
dict({
'parts': '1st model response',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'2nd user request',
),
dict({
}),
),
])
# ---
# name: test_default_prompt[config_entry_options0-None] # name: test_default_prompt[config_entry_options0-None]
list([ list([
tuple( tuple(
@ -133,6 +239,13 @@
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
<no_api_prompt>
''',
'tools': None, 'tools': None,
}), }),
), ),
@ -142,19 +255,6 @@
), ),
dict({ dict({
'history': list([ 'history': list([
dict({
'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer in plain text. Keep it simple and to the point.
<no_api_prompt>
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]), ]),
}), }),
), ),
@ -188,6 +288,13 @@
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
<no_api_prompt>
''',
'tools': None, 'tools': None,
}), }),
), ),
@ -197,19 +304,6 @@
), ),
dict({ dict({
'history': list([ 'history': list([
dict({
'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer in plain text. Keep it simple and to the point.
<no_api_prompt>
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]), ]),
}), }),
), ),
@ -243,6 +337,13 @@
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
<api_prompt>
''',
'tools': None, 'tools': None,
}), }),
), ),
@ -252,19 +353,6 @@
), ),
dict({ dict({
'history': list([ 'history': list([
dict({
'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer in plain text. Keep it simple and to the point.
<api_prompt>
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]), ]),
}), }),
), ),
@ -298,6 +386,13 @@
'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
<api_prompt>
''',
'tools': None, 'tools': None,
}), }),
), ),
@ -307,19 +402,6 @@
), ),
dict({ dict({
'history': list([ 'history': list([
dict({
'parts': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer in plain text. Keep it simple and to the point.
<api_prompt>
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]), ]),
}), }),
), ),

View File

@ -12,6 +12,9 @@ import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import trace from homeassistant.components.conversation import trace
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
)
from homeassistant.components.google_generative_ai_conversation.conversation import ( from homeassistant.components.google_generative_ai_conversation.conversation import (
_escape_decode, _escape_decode,
) )
@ -99,13 +102,22 @@ async def test_default_prompt(
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
@pytest.mark.parametrize(
("model_name", "supports_system_instruction"),
[("models/gemini-1.5-pro", True), ("models/gemini-1.0-pro", False)],
)
async def test_chat_history( async def test_chat_history(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_init_component, mock_init_component,
model_name: str,
supports_system_instruction: bool,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that the agent keeps track of the chat history.""" """Test that the agent keeps track of the chat history."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_CHAT_MODEL: model_name}
)
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
@ -115,9 +127,14 @@ async def test_chat_history(
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "1st model response" mock_part.text = "1st model response"
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
if supports_system_instruction:
mock_chat.history = []
else:
mock_chat.history = [ mock_chat.history = [
{"role": "user", "parts": "prompt"}, {"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"}, {"role": "model", "parts": "Ok"},
]
mock_chat.history += [
{"role": "user", "parts": "1st user request"}, {"role": "user", "parts": "1st user request"},
{"role": "model", "parts": "1st model response"}, {"role": "model", "parts": "1st model response"},
] ]
@ -256,7 +273,7 @@ async def test_function_call(
] ]
# AGENT_DETAIL event contains the raw prompt passed to the model # AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1] detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"] assert "Answer in plain text" in detail_event["data"]["prompt"]
@patch( @patch(
@ -492,9 +509,9 @@ async def test_template_variables(
), result ), result
assert ( assert (
"The user name is Test User." "The user name is Test User."
in mock_model.mock_calls[1][2]["history"][0]["parts"] in mock_model.mock_calls[0][2]["system_instruction"]
) )
assert "The user id is 12345." in mock_model.mock_calls[1][2]["history"][0]["parts"] assert "The user id is 12345." in mock_model.mock_calls[0][2]["system_instruction"]
async def test_conversation_agent( async def test_conversation_agent(