mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Pass prompt as system_instruction for Gemini 1.5 models (#120147)
This commit is contained in:
parent
57eb8dab6a
commit
ad1f0db5a4
@ -161,10 +161,14 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
result = conversation.ConversationResult(
|
||||||
llm_api: llm.APIInstance | None = None
|
response=intent.IntentResponse(language=user_input.language),
|
||||||
tools: list[dict[str, Any]] | None = None
|
conversation_id=user_input.conversation_id
|
||||||
user_name: str | None = None
|
if user_input.conversation_id in self.history
|
||||||
|
else ulid.ulid_now(),
|
||||||
|
)
|
||||||
|
assert result.conversation_id
|
||||||
|
|
||||||
llm_context = llm.LLMContext(
|
llm_context = llm.LLMContext(
|
||||||
platform=DOMAIN,
|
platform=DOMAIN,
|
||||||
context=user_input.context,
|
context=user_input.context,
|
||||||
@ -173,7 +177,8 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
assistant=conversation.DOMAIN,
|
assistant=conversation.DOMAIN,
|
||||||
device_id=user_input.device_id,
|
device_id=user_input.device_id,
|
||||||
)
|
)
|
||||||
|
llm_api: llm.APIInstance | None = None
|
||||||
|
tools: list[dict[str, Any]] | None = None
|
||||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||||
try:
|
try:
|
||||||
llm_api = await llm.async_get_api(
|
llm_api = await llm.async_get_api(
|
||||||
@ -183,17 +188,33 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
)
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
LOGGER.error("Error getting LLM API: %s", err)
|
LOGGER.error("Error getting LLM API: %s", err)
|
||||||
intent_response.async_set_error(
|
result.response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
f"Error preparing LLM API: {err}",
|
f"Error preparing LLM API: {err}",
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return result
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
|
||||||
)
|
|
||||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
|
||||||
|
except TemplateError as err:
|
||||||
|
LOGGER.error("Error rendering prompt: %s", err)
|
||||||
|
result.response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
f"Sorry, I had a problem with my template: {err}",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||||
|
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
|
||||||
|
# Assume future versions will support it (if not, the request fails with a
|
||||||
|
# clear message at which point we can fix).
|
||||||
|
supports_system_instruction = (
|
||||||
|
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
|
||||||
|
)
|
||||||
|
|
||||||
model = genai.GenerativeModel(
|
model = genai.GenerativeModel(
|
||||||
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model_name=model_name,
|
||||||
generation_config={
|
generation_config={
|
||||||
"temperature": self.entry.options.get(
|
"temperature": self.entry.options.get(
|
||||||
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||||
@ -219,69 +240,25 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
tools=tools or None,
|
tools=tools or None,
|
||||||
|
system_instruction=prompt if supports_system_instruction else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_input.conversation_id in self.history:
|
messages = self.history.get(result.conversation_id, [])
|
||||||
conversation_id = user_input.conversation_id
|
if not supports_system_instruction:
|
||||||
messages = self.history[conversation_id]
|
if not messages:
|
||||||
else:
|
messages = [{}, {"role": "model", "parts": "Ok"}]
|
||||||
conversation_id = ulid.ulid_now()
|
messages[0] = {"role": "user", "parts": prompt}
|
||||||
messages = [{}, {"role": "model", "parts": "Ok"}]
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_input.context
|
|
||||||
and user_input.context.user_id
|
|
||||||
and (
|
|
||||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
user_name = user.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
if llm_api:
|
|
||||||
api_prompt = llm_api.api_prompt
|
|
||||||
else:
|
|
||||||
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
|
||||||
|
|
||||||
prompt = "\n".join(
|
|
||||||
(
|
|
||||||
template.Template(
|
|
||||||
llm.BASE_PROMPT
|
|
||||||
+ self.entry.options.get(
|
|
||||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
|
||||||
),
|
|
||||||
self.hass,
|
|
||||||
).async_render(
|
|
||||||
{
|
|
||||||
"ha_name": self.hass.config.location_name,
|
|
||||||
"user_name": user_name,
|
|
||||||
"llm_context": llm_context,
|
|
||||||
},
|
|
||||||
parse_result=False,
|
|
||||||
),
|
|
||||||
api_prompt,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
except TemplateError as err:
|
|
||||||
LOGGER.error("Error rendering prompt: %s", err)
|
|
||||||
intent_response.async_set_error(
|
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
|
||||||
f"Sorry, I had a problem with my template: {err}",
|
|
||||||
)
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make a copy, because we attach it to the trace event.
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "parts": prompt},
|
|
||||||
*messages[1:],
|
|
||||||
]
|
|
||||||
|
|
||||||
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
||||||
trace.async_conversation_trace_append(
|
trace.async_conversation_trace_append(
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
{
|
||||||
|
# Make a copy to attach it to the trace event.
|
||||||
|
"messages": messages[:]
|
||||||
|
if supports_system_instruction
|
||||||
|
else messages[2:],
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = model.start_chat(history=messages)
|
chat = model.start_chat(history=messages)
|
||||||
@ -307,24 +284,20 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
||||||
)
|
)
|
||||||
|
|
||||||
intent_response.async_set_error(
|
result.response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
error,
|
error,
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return result
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGGER.debug("Response: %s", chat_response.parts)
|
LOGGER.debug("Response: %s", chat_response.parts)
|
||||||
if not chat_response.parts:
|
if not chat_response.parts:
|
||||||
intent_response.async_set_error(
|
result.response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
"Sorry, I had a problem getting a response from Google Generative AI.",
|
"Sorry, I had a problem getting a response from Google Generative AI.",
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return result
|
||||||
response=intent_response, conversation_id=conversation_id
|
self.history[result.conversation_id] = chat.history
|
||||||
)
|
|
||||||
self.history[conversation_id] = chat.history
|
|
||||||
function_calls = [
|
function_calls = [
|
||||||
part.function_call for part in chat_response.parts if part.function_call
|
part.function_call for part in chat_response.parts if part.function_call
|
||||||
]
|
]
|
||||||
@ -355,9 +328,48 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
)
|
)
|
||||||
chat_request = protos.Content(parts=tool_responses)
|
chat_request = protos.Content(parts=tool_responses)
|
||||||
|
|
||||||
intent_response.async_set_speech(
|
result.response.async_set_speech(
|
||||||
" ".join([part.text.strip() for part in chat_response.parts if part.text])
|
" ".join([part.text.strip() for part in chat_response.parts if part.text])
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return result
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
|
async def _async_render_prompt(
|
||||||
|
self,
|
||||||
|
user_input: conversation.ConversationInput,
|
||||||
|
llm_api: llm.APIInstance | None,
|
||||||
|
llm_context: llm.LLMContext,
|
||||||
|
) -> str:
|
||||||
|
user_name: str | None = None
|
||||||
|
if (
|
||||||
|
user_input.context
|
||||||
|
and user_input.context.user_id
|
||||||
|
and (
|
||||||
|
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
user_name = user.name
|
||||||
|
|
||||||
|
if llm_api:
|
||||||
|
api_prompt = llm_api.api_prompt
|
||||||
|
else:
|
||||||
|
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
||||||
|
|
||||||
|
return "\n".join(
|
||||||
|
(
|
||||||
|
template.Template(
|
||||||
|
llm.BASE_PROMPT
|
||||||
|
+ self.entry.options.get(
|
||||||
|
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||||
|
),
|
||||||
|
self.hass,
|
||||||
|
).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
"user_name": user_name,
|
||||||
|
"llm_context": llm_context,
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
),
|
||||||
|
api_prompt,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
@ -43,6 +43,7 @@ BASE_PROMPT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant.
|
DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_chat_history
|
# name: test_chat_history[models/gemini-1.0-pro-False]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
'',
|
'',
|
||||||
@ -12,13 +12,14 @@
|
|||||||
'top_k': 64,
|
'top_k': 64,
|
||||||
'top_p': 0.95,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.0-pro',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': None,
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -32,6 +33,7 @@
|
|||||||
'parts': '''
|
'parts': '''
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
You are a voice assistant for Home Assistant.
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
''',
|
''',
|
||||||
@ -63,13 +65,14 @@
|
|||||||
'top_k': 64,
|
'top_k': 64,
|
||||||
'top_p': 0.95,
|
'top_p': 0.95,
|
||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.0-pro',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': None,
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -83,6 +86,7 @@
|
|||||||
'parts': '''
|
'parts': '''
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
You are a voice assistant for Home Assistant.
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
''',
|
''',
|
||||||
@ -113,6 +117,108 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_chat_history[models/gemini-1.5-pro-True]
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-pro',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'1st user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-pro',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
dict({
|
||||||
|
'parts': '1st user request',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': '1st model response',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'2nd user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options0-None]
|
# name: test_default_prompt[config_entry_options0-None]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
@ -133,6 +239,13 @@
|
|||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
<no_api_prompt>
|
||||||
|
''',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -142,19 +255,6 @@
|
|||||||
),
|
),
|
||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
|
||||||
'parts': '''
|
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
|
||||||
You are a voice assistant for Home Assistant.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
|
||||||
<no_api_prompt>
|
|
||||||
''',
|
|
||||||
'role': 'user',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'parts': 'Ok',
|
|
||||||
'role': 'model',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -188,6 +288,13 @@
|
|||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
<no_api_prompt>
|
||||||
|
''',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -197,19 +304,6 @@
|
|||||||
),
|
),
|
||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
|
||||||
'parts': '''
|
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
|
||||||
You are a voice assistant for Home Assistant.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
|
||||||
<no_api_prompt>
|
|
||||||
''',
|
|
||||||
'role': 'user',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'parts': 'Ok',
|
|
||||||
'role': 'model',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -243,6 +337,13 @@
|
|||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
<api_prompt>
|
||||||
|
''',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -252,19 +353,6 @@
|
|||||||
),
|
),
|
||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
|
||||||
'parts': '''
|
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
|
||||||
You are a voice assistant for Home Assistant.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
|
||||||
<api_prompt>
|
|
||||||
''',
|
|
||||||
'role': 'user',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'parts': 'Ok',
|
|
||||||
'role': 'model',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -298,6 +386,13 @@
|
|||||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
<api_prompt>
|
||||||
|
''',
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -307,19 +402,6 @@
|
|||||||
),
|
),
|
||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
|
||||||
'parts': '''
|
|
||||||
Current time is 05:00:00. Today's date is 2024-05-24.
|
|
||||||
You are a voice assistant for Home Assistant.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
|
||||||
<api_prompt>
|
|
||||||
''',
|
|
||||||
'role': 'user',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'parts': 'Ok',
|
|
||||||
'role': 'model',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
@ -12,6 +12,9 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
|
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||||
|
CONF_CHAT_MODEL,
|
||||||
|
)
|
||||||
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||||
_escape_decode,
|
_escape_decode,
|
||||||
)
|
)
|
||||||
@ -99,13 +102,22 @@ async def test_default_prompt(
|
|||||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("model_name", "supports_system_instruction"),
|
||||||
|
[("models/gemini-1.5-pro", True), ("models/gemini-1.0-pro", False)],
|
||||||
|
)
|
||||||
async def test_chat_history(
|
async def test_chat_history(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
model_name: str,
|
||||||
|
supports_system_instruction: bool,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the agent keeps track of the chat history."""
|
"""Test that the agent keeps track of the chat history."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry, options={CONF_CHAT_MODEL: model_name}
|
||||||
|
)
|
||||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
mock_chat = AsyncMock()
|
mock_chat = AsyncMock()
|
||||||
mock_model.return_value.start_chat.return_value = mock_chat
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
@ -115,9 +127,14 @@ async def test_chat_history(
|
|||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
mock_part.text = "1st model response"
|
mock_part.text = "1st model response"
|
||||||
chat_response.parts = [mock_part]
|
chat_response.parts = [mock_part]
|
||||||
mock_chat.history = [
|
if supports_system_instruction:
|
||||||
{"role": "user", "parts": "prompt"},
|
mock_chat.history = []
|
||||||
{"role": "model", "parts": "Ok"},
|
else:
|
||||||
|
mock_chat.history = [
|
||||||
|
{"role": "user", "parts": "prompt"},
|
||||||
|
{"role": "model", "parts": "Ok"},
|
||||||
|
]
|
||||||
|
mock_chat.history += [
|
||||||
{"role": "user", "parts": "1st user request"},
|
{"role": "user", "parts": "1st user request"},
|
||||||
{"role": "model", "parts": "1st model response"},
|
{"role": "model", "parts": "1st model response"},
|
||||||
]
|
]
|
||||||
@ -256,7 +273,7 @@ async def test_function_call(
|
|||||||
]
|
]
|
||||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
detail_event = trace_events[1]
|
detail_event = trace_events[1]
|
||||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
|
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@ -492,9 +509,9 @@ async def test_template_variables(
|
|||||||
), result
|
), result
|
||||||
assert (
|
assert (
|
||||||
"The user name is Test User."
|
"The user name is Test User."
|
||||||
in mock_model.mock_calls[1][2]["history"][0]["parts"]
|
in mock_model.mock_calls[0][2]["system_instruction"]
|
||||||
)
|
)
|
||||||
assert "The user id is 12345." in mock_model.mock_calls[1][2]["history"][0]["parts"]
|
assert "The user id is 12345." in mock_model.mock_calls[0][2]["system_instruction"]
|
||||||
|
|
||||||
|
|
||||||
async def test_conversation_agent(
|
async def test_conversation_agent(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user