mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Google Generative AI: 100% test coverage for conversation (#118112)
100% coverage for conversation
This commit is contained in:
parent
8fbe39f2a7
commit
0182bfcc81
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import google.ai.generativelanguage as glm
|
import google.ai.generativelanguage as glm
|
||||||
from google.api_core.exceptions import ClientError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import google.generativeai.types as genai_types
|
import google.generativeai.types as genai_types
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -258,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
try:
|
try:
|
||||||
chat_response = await chat.send_message_async(chat_request)
|
chat_response = await chat.send_message_async(chat_request)
|
||||||
except (
|
except (
|
||||||
ClientError,
|
GoogleAPICallError,
|
||||||
ValueError,
|
ValueError,
|
||||||
genai_types.BlockedPromptException,
|
genai_types.BlockedPromptException,
|
||||||
genai_types.StopCandidateException,
|
genai_types.StopCandidateException,
|
||||||
|
@ -1,4 +1,114 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_chat_history
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
dict({
|
||||||
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': 'Ok',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'1st user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
dict({
|
||||||
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': 'Ok',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': '1st user request',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': '1st model response',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'2nd user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options0-None]
|
# name: test_default_prompt[config_entry_options0-None]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
|
@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from google.api_core.exceptions import ClientError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
|
import google.generativeai.types as genai_types
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -150,6 +151,57 @@ async def test_default_prompt(
|
|||||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_history(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the agent keeps track of the chat history."""
|
||||||
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
|
chat_response = MagicMock()
|
||||||
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
|
mock_part = MagicMock()
|
||||||
|
mock_part.function_call = None
|
||||||
|
chat_response.parts = [mock_part]
|
||||||
|
chat_response.text = "1st model response"
|
||||||
|
mock_chat.history = [
|
||||||
|
{"role": "user", "parts": "prompt"},
|
||||||
|
{"role": "model", "parts": "Ok"},
|
||||||
|
{"role": "user", "parts": "1st user request"},
|
||||||
|
{"role": "model", "parts": "1st model response"},
|
||||||
|
]
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"1st user request",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "1st model response"
|
||||||
|
)
|
||||||
|
chat_response.text = "2nd model response"
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"2nd user request",
|
||||||
|
result.conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "2nd model response"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||||
)
|
)
|
||||||
@ -325,7 +377,7 @@ async def test_error_handling(
|
|||||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
mock_chat = AsyncMock()
|
mock_chat = AsyncMock()
|
||||||
mock_model.return_value.start_chat.return_value = mock_chat
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
mock_chat.send_message_async.side_effect = ClientError("some error")
|
mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
@ -340,7 +392,28 @@ async def test_error_handling(
|
|||||||
async def test_blocked_response(
|
async def test_blocked_response(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test response was blocked."""
|
"""Test blocked response."""
|
||||||
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
|
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
|
||||||
|
"finish_reason: SAFETY\n"
|
||||||
|
)
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
|
"The message got blocked by your safety settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_empty_response(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test empty response."""
|
||||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
mock_chat = AsyncMock()
|
mock_chat = AsyncMock()
|
||||||
mock_model.return_value.start_chat.return_value = mock_chat
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
@ -358,6 +431,32 @@ async def test_blocked_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_llm_api(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test handling of invalid llm api."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
|
"Error preparing LLM API: API invalid_llm_api not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_template_error(
|
async def test_template_error(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user