From 0182bfcc81900dacdb7c3ac8daee19f6a08d1d39 Mon Sep 17 00:00:00 2001 From: tronikos Date: Sat, 25 May 2024 04:52:20 -0700 Subject: [PATCH] Google Generative AI: 100% test coverage for conversation (#118112) 100% coverage for conversation --- .../conversation.py | 4 +- .../snapshots/test_conversation.ambr | 110 ++++++++++++++++++ .../test_conversation.py | 105 ++++++++++++++++- 3 files changed, 214 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 627b28d0966..8a6a761d549 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import Any, Literal import google.ai.generativelanguage as glm -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import GoogleAPICallError import google.generativeai as genai import google.generativeai.types as genai_types import voluptuous as vol @@ -258,7 +258,7 @@ class GoogleGenerativeAIConversationEntity( try: chat_response = await chat.send_message_async(chat_request) except ( - ClientError, + GoogleAPICallError, ValueError, genai_types.BlockedPromptException, genai_types.StopCandidateException, diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index e1f8141a692..6d37c1d1823 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -1,4 +1,114 @@ # serializer version: 1 +# name: test_chat_history + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '1st user request', + ), + dict({ + }), + ), + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + dict({ + 'parts': '1st user request', + 'role': 'user', + }), + dict({ + 'parts': '1st model response', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '2nd user request', + ), + dict({ + }), + ), + ]) +# --- # name: test_default_prompt[config_entry_options0-None] list([ tuple( diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index af7aebace35..b31d9442a43 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -2,7 +2,8 @@ from unittest.mock import AsyncMock, MagicMock, patch -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import GoogleAPICallError +import google.generativeai.types as genai_types import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol @@ -150,6 +151,57 @@ async def test_default_prompt( assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) +async def test_chat_history( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test that the agent keeps track of the chat history.""" + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + chat_response = MagicMock() + mock_chat.send_message_async.return_value = chat_response + mock_part = MagicMock() + mock_part.function_call = None + chat_response.parts = [mock_part] + chat_response.text = "1st model response" + mock_chat.history = [ + {"role": "user", "parts": "prompt"}, + {"role": "model", "parts": "Ok"}, + {"role": "user", "parts": "1st user request"}, + {"role": "model", "parts": "1st model response"}, + ] + result = await conversation.async_converse( + hass, + "1st user request", + None, + Context(), + agent_id=mock_config_entry.entry_id, + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "1st model response" + ) + chat_response.text = "2nd model response" + result = await conversation.async_converse( + hass, + "2nd user request", + result.conversation_id, + Context(), + agent_id=mock_config_entry.entry_id, + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "2nd model response" + ) + + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" ) @@ -325,7 +377,7 @@ async def test_error_handling( with patch("google.generativeai.GenerativeModel") as mock_model: mock_chat = AsyncMock() mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = ClientError("some error") + mock_chat.send_message_async.side_effect = GoogleAPICallError("some error") result = await conversation.async_converse( hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id ) @@ -340,7 +392,28 @@ async def test_error_handling( async def test_blocked_response( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component ) -> None: - """Test response was blocked.""" + """Test blocked response.""" + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + mock_chat.send_message_async.side_effect = genai_types.StopCandidateException( + "finish_reason: SAFETY\n" + ) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + assert result.response.as_dict()["speech"]["plain"]["speech"] == ( + "The message got blocked by your safety settings" + ) + + +async def test_empty_response( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test empty response.""" with patch("google.generativeai.GenerativeModel") as mock_model: mock_chat = AsyncMock() mock_model.return_value.start_chat.return_value = mock_chat @@ -358,6 +431,32 @@ async def test_blocked_response( ) +async def test_invalid_llm_api( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test handling of invalid llm api.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, + ) + + result = await conversation.async_converse( + hass, + "hello", + None, + Context(), + agent_id=mock_config_entry.entry_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + assert result.response.as_dict()["speech"]["plain"]["speech"] == ( + "Error preparing LLM API: API invalid_llm_api not found" + ) + + async def test_template_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: