Google Generative AI: Add a service for prompts consisting of text and images using Gemini Pro Vision (#105789)

* Bump google-generativeai to 0.3.1

* Migrate to the new API and default to gemini-pro

* Add max output tokens option

* Add generate_content service

* Add tests

* Add additional checks

* Make read_bytes async

* Add tests for all errors
tronikos 2024-01-07 13:21:27 -08:00 committed by GitHub
parent fd52172c33
commit 810c6ea5ae
8 changed files with 450 additions and 74 deletions
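
For orientation, a minimal sketch of how the new service can be invoked once this change is installed (mirroring the calls exercised in the tests below; the image path is illustrative and must live under an allowed path):

    # Sketch: calling the new service from Python, e.g. inside a test.
    # The service is registered with SupportsResponse.ONLY, so it must be
    # called with return_response=True.
    response = await hass.services.async_call(
        "google_generative_ai_conversation",
        "generate_content",
        {
            "prompt": "Describe what you see in these images:",
            "image_filename": ["/config/www/doorbell_snapshot.jpg"],
        },
        blocking=True,
        return_response=True,
    )
    # response == {"text": "<generated text>"}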

homeassistant/components/google_generative_ai_conversation/__init__.py

@@ -3,44 +3,122 @@ from __future__ import annotations

 from functools import partial
 import logging
+import mimetypes
+from pathlib import Path
 from typing import Literal

 from google.api_core.exceptions import ClientError
-import google.generativeai as palm
-from google.generativeai.types.discuss_types import ChatResponse
+import google.generativeai as genai
+import google.generativeai.types as genai_types
+import voluptuous as vol

 from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_API_KEY, MATCH_ALL
-from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
-from homeassistant.helpers import intent, template
+from homeassistant.core import (
+    HomeAssistant,
+    ServiceCall,
+    ServiceResponse,
+    SupportsResponse,
+)
+from homeassistant.exceptions import (
+    ConfigEntryNotReady,
+    HomeAssistantError,
+    TemplateError,
+)
+from homeassistant.helpers import config_validation as cv, intent, template
+from homeassistant.helpers.typing import ConfigType
 from homeassistant.util import ulid

 from .const import (
     CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_CHAT_MODEL,
+    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOP_K,
     DEFAULT_TOP_P,
+    DOMAIN,
 )

 _LOGGER = logging.getLogger(__name__)

+SERVICE_GENERATE_CONTENT = "generate_content"
+CONF_IMAGE_FILENAME = "image_filename"
+
+CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
+
+
+async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
+    """Set up Google Generative AI Conversation."""
+
+    async def generate_content(call: ServiceCall) -> ServiceResponse:
+        """Generate content from text and optionally images."""
+        prompt_parts = [call.data[CONF_PROMPT]]
+        image_filenames = call.data[CONF_IMAGE_FILENAME]
+        for image_filename in image_filenames:
+            if not hass.config.is_allowed_path(image_filename):
+                raise HomeAssistantError(
+                    f"Cannot read `{image_filename}`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`"
+                )
+            if not Path(image_filename).exists():
+                raise HomeAssistantError(f"`{image_filename}` does not exist")
+            mime_type, _ = mimetypes.guess_type(image_filename)
+            if mime_type is None or not mime_type.startswith("image"):
+                raise HomeAssistantError(f"`{image_filename}` is not an image")
+            prompt_parts.append(
+                {
+                    "mime_type": mime_type,
+                    "data": await hass.async_add_executor_job(
+                        Path(image_filename).read_bytes
+                    ),
+                }
+            )
+
+        model_name = "gemini-pro-vision" if image_filenames else "gemini-pro"
+        model = genai.GenerativeModel(model_name=model_name)
+
+        try:
+            response = await model.generate_content_async(prompt_parts)
+        except (
+            ClientError,
+            ValueError,
+            genai_types.BlockedPromptException,
+            genai_types.StopCandidateException,
+        ) as err:
+            raise HomeAssistantError(f"Error generating content: {err}") from err
+
+        return {"text": response.text}
+
+    hass.services.async_register(
+        DOMAIN,
+        SERVICE_GENERATE_CONTENT,
+        generate_content,
+        schema=vol.Schema(
+            {
+                vol.Required(CONF_PROMPT): cv.string,
+                vol.Optional(CONF_IMAGE_FILENAME, default=[]): vol.All(
+                    cv.ensure_list, [cv.string]
+                ),
+            }
+        ),
+        supports_response=SupportsResponse.ONLY,
+    )
+    return True
+

 async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
     """Set up Google Generative AI Conversation from a config entry."""
-    palm.configure(api_key=entry.data[CONF_API_KEY])
+    genai.configure(api_key=entry.data[CONF_API_KEY])
     try:
         await hass.async_add_executor_job(
             partial(
-                palm.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
+                genai.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
             )
         )
     except ClientError as err:
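
As an aside on the schema registered above: assuming standard voluptuous and cv.ensure_list semantics, a single image_filename string is coerced into a one-element list before the handler iterates over it, which is why callers (and the tests below) may pass either a bare string or a list:

    # Sketch, assuming standard voluptuous + cv.ensure_list behavior.
    import voluptuous as vol
    from homeassistant.helpers import config_validation as cv

    schema = vol.Schema(
        {
            vol.Required("prompt"): cv.string,
            vol.Optional("image_filename", default=[]): vol.All(
                cv.ensure_list, [cv.string]
            ),
        }
    )
    # A bare string validates to a one-element list.
    assert schema({"prompt": "hi", "image_filename": "a.jpg"}) == {
        "prompt": "hi",
        "image_filename": ["a.jpg"],
    }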
@@ -55,7 +133,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

 async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
     """Unload GoogleGenerativeAI."""
-    palm.configure(api_key=None)
+    genai.configure(api_key=None)
     conversation.async_unset_agent(hass, entry)
     return True
@@ -67,7 +145,7 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
         """Initialize the agent."""
         self.hass = hass
         self.entry = entry
-        self.history: dict[str, list[dict]] = {}
+        self.history: dict[str, list[genai_types.ContentType]] = {}

     @property
     def supported_languages(self) -> list[str] | Literal["*"]:
@@ -79,17 +157,27 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
     ) -> conversation.ConversationResult:
         """Process a sentence."""
         raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
+        model = genai.GenerativeModel(
+            model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
+            generation_config={
+                "temperature": self.entry.options.get(
+                    CONF_TEMPERATURE, DEFAULT_TEMPERATURE
+                ),
+                "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
+                "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
+                "max_output_tokens": self.entry.options.get(
+                    CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
+                ),
+            },
+        )
+        _LOGGER.debug("Model: %s", model)

         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
             messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid_now()
-            messages = []
+            messages = [{}, {}]

         try:
             prompt = self._async_generate_prompt(raw_prompt)
@@ -104,20 +192,21 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
             response=intent_response, conversation_id=conversation_id
         )

-        messages.append({"author": "0", "content": user_input.text})
+        messages[0] = {"role": "user", "parts": prompt}
+        messages[1] = {"role": "model", "parts": "Ok"}

-        _LOGGER.debug("Prompt for %s: %s", model, messages)
+        _LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
+        chat = model.start_chat(history=messages)
         try:
-            chat_response: ChatResponse = await palm.chat_async(
-                model=model,
-                context=prompt,
-                messages=messages,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-            )
-        except ClientError as err:
+            chat_response = await chat.send_message_async(user_input.text)
+        except (
+            ClientError,
+            ValueError,
+            genai_types.BlockedPromptException,
+            genai_types.StopCandidateException,
+        ) as err:
+            _LOGGER.error("Error sending message: %s", err)
             intent_response = intent.IntentResponse(language=user_input.language)
             intent_response.async_set_error(
                 intent.IntentResponseErrorCode.UNKNOWN,
@@ -127,14 +216,11 @@ class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
             response=intent_response, conversation_id=conversation_id
         )

-        _LOGGER.debug("Response %s", chat_response)
-        # For some queries the response is empty. In that case don't update history to avoid
-        # "google.generativeai.types.discuss_types.AuthorError: Authors are not strictly alternating"
-        if chat_response.last:
-            self.history[conversation_id] = chat_response.messages
+        _LOGGER.debug("Response: %s", chat_response.parts)
+        self.history[conversation_id] = chat.history
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(chat_response.last)
+        intent_response.async_set_speech(chat_response.text)
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )

homeassistant/components/google_generative_ai_conversation/config_flow.py

@@ -8,7 +8,7 @@ from types import MappingProxyType
 from typing import Any

 from google.api_core.exceptions import ClientError
-import google.generativeai as palm
+import google.generativeai as genai
 import voluptuous as vol

 from homeassistant import config_entries
@@ -23,11 +23,13 @@ from homeassistant.helpers.selector import (
 from .const import (
     CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_CHAT_MODEL,
+    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOP_K,
@@ -50,6 +52,7 @@ DEFAULT_OPTIONS = types.MappingProxyType(
         CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
         CONF_TOP_P: DEFAULT_TOP_P,
         CONF_TOP_K: DEFAULT_TOP_K,
+        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
     }
 )
@@ -59,8 +62,8 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:

     Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
     """
-    palm.configure(api_key=data[CONF_API_KEY])
-    await hass.async_add_executor_job(partial(palm.list_models))
+    genai.configure(api_key=data[CONF_API_KEY])
+    await hass.async_add_executor_job(partial(genai.list_models))


 class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@@ -162,4 +165,9 @@ def google_generative_ai_config_option_schema(
             description={"suggested_value": options[CONF_TOP_K]},
             default=DEFAULT_TOP_K,
         ): int,
+        vol.Optional(
+            CONF_MAX_TOKENS,
+            description={"suggested_value": options[CONF_MAX_TOKENS]},
+            default=DEFAULT_MAX_TOKENS,
+        ): int,
     }

homeassistant/components/google_generative_ai_conversation/const.py

@@ -24,10 +24,12 @@ Answer the user's questions about the world truthfully.

 If the user wants to control a device, reject the request and suggest using the Home Assistant app.
 """
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "models/chat-bison-001"
+DEFAULT_CHAT_MODEL = "models/gemini-pro"
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 0.25
+DEFAULT_TEMPERATURE = 0.9
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 0.95
+DEFAULT_TOP_P = 1.0
 CONF_TOP_K = "top_k"
-DEFAULT_TOP_K = 40
+DEFAULT_TOP_K = 1
+CONF_MAX_TOKENS = "max_tokens"
+DEFAULT_MAX_TOKENS = 150

homeassistant/components/google_generative_ai_conversation/services.yaml

@@ -0,0 +1,11 @@
+generate_content:
+  fields:
+    prompt:
+      required: true
+      selector:
+        text:
+          multiline: true
+    image_filename:
+      required: false
+      selector:
+        object:

homeassistant/components/google_generative_ai_conversation/strings.json

@@ -21,7 +21,26 @@
           "model": "[%key:common::generic::model%]",
           "temperature": "Temperature",
           "top_p": "Top P",
-          "top_k": "Top K"
+          "top_k": "Top K",
+          "max_tokens": "Maximum tokens to return in response"
         }
       }
     }
-  }
+  },
+  "services": {
+    "generate_content": {
+      "name": "Generate content",
+      "description": "Generate content from a prompt consisting of text and optionally images",
+      "fields": {
+        "prompt": {
+          "name": "Prompt",
+          "description": "The prompt",
+          "example": "Describe what you see in these images:"
+        },
+        "image_filename": {
+          "name": "Image filename",
+          "description": "Images",
+          "example": "/config/www/image.jpg"
+        }
+      }
+    }
+  }
 }

tests/components/google_generative_ai_conversation/snapshots/test_init.ambr

@@ -1,33 +1,109 @@
 # serializer version: 1
 # name: test_default_prompt
-  dict({
-    'context': '''
-      This smart home is controlled by Home Assistant.
-
-      An overview of the areas and the devices in this smart home:
-      Test Area:
-      - Test Device (Test Model)
-      Test Area 2:
-      - Test Device 2
-      - Test Device 3 (Test Model 3A)
-      - Test Device 4
-      - 1 (3)
-
-      Answer the user's questions about the world truthfully.
-
-      If the user wants to control a device, reject the request and suggest using the Home Assistant app.
-    ''',
-    'messages': list([
-      dict({
-        'author': '0',
-        'content': 'hello',
-      }),
-    ]),
-    'model': 'models/chat-bison-001',
-    'temperature': 0.25,
-    'top_k': 40,
-    'top_p': 0.95,
-  })
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'generation_config': dict({
+          'max_output_tokens': 150,
+          'temperature': 0.9,
+          'top_k': 1,
+          'top_p': 1.0,
+        }),
+        'model_name': 'models/gemini-pro',
+      }),
+    ),
+    tuple(
+      '().start_chat',
+      tuple(
+      ),
+      dict({
+        'history': list([
+          dict({
+            'parts': '''
+              This smart home is controlled by Home Assistant.
+
+              An overview of the areas and the devices in this smart home:
+              Test Area:
+              - Test Device (Test Model)
+              Test Area 2:
+              - Test Device 2
+              - Test Device 3 (Test Model 3A)
+              - Test Device 4
+              - 1 (3)
+
+              Answer the user's questions about the world truthfully.
+
+              If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+            ''',
+            'role': 'user',
+          }),
+          dict({
+            'parts': 'Ok',
+            'role': 'model',
+          }),
+        ]),
+      }),
+    ),
+    tuple(
+      '().start_chat().send_message_async',
+      tuple(
+        'hello',
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
+# name: test_generate_content_service_with_image
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'model_name': 'gemini-pro-vision',
+      }),
+    ),
+    tuple(
+      '().generate_content_async',
+      tuple(
+        list([
+          'Describe this image from my doorbell camera',
+          dict({
+            'data': b'image bytes',
+            'mime_type': 'image/jpeg',
+          }),
+        ]),
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
+# name: test_generate_content_service_without_images
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'model_name': 'gemini-pro',
+      }),
+    ),
+    tuple(
+      '().generate_content_async',
+      tuple(
+        list([
+          'Write an opening speech for a Home Assistant release party',
+        ]),
+      ),
+      dict({
+      }),
+    ),
+  ])
 # ---
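
The snapshots above are lists of (name, args, kwargs) triples because every entry in unittest.mock's Mock.mock_calls unpacks into exactly that tuple; the tests below capture them with tuple(mock_call). A minimal sketch of the mechanics:

    # Sketch: how patching a class records calls as (name, args, kwargs).
    from unittest.mock import MagicMock

    mock_model = MagicMock()
    instance = mock_model(model_name="gemini-pro")  # recorded under name ""
    instance.generate_content_async(["some prompt"])  # "().generate_content_async"

    name, args, kwargs = tuple(mock_model.mock_calls[0])
    assert (name, args, kwargs) == ("", (), {"model_name": "gemini-pro"})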

tests/components/google_generative_ai_conversation/test_config_flow.py

@@ -8,9 +8,11 @@ import pytest

 from homeassistant import config_entries
 from homeassistant.components.google_generative_ai_conversation.const import (
     CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_CHAT_MODEL,
+    DEFAULT_MAX_TOKENS,
     DEFAULT_TOP_K,
     DEFAULT_TOP_P,
     DOMAIN,
@@ -37,7 +39,7 @@ async def test_form(hass: HomeAssistant) -> None:
     assert result["errors"] is None

     with patch(
-        "homeassistant.components.google_generative_ai_conversation.config_flow.palm.list_models",
+        "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
     ), patch(
         "homeassistant.components.google_generative_ai_conversation.async_setup_entry",
         return_value=True,
@@ -78,6 +80,7 @@ async def test_options(
     assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
     assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
     assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
+    assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS


 @pytest.mark.parametrize(
@@ -104,7 +107,7 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
     )

     with patch(
-        "homeassistant.components.google_generative_ai_conversation.config_flow.palm.list_models",
+        "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
         side_effect=side_effect,
     ):
         result2 = await hass.config_entries.flow.async_configure(

tests/components/google_generative_ai_conversation/test_init.py

@@ -1,11 +1,13 @@
 """Tests for the Google Generative AI Conversation integration."""
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch

 from google.api_core.exceptions import ClientError
+import pytest
 from syrupy.assertion import SnapshotAssertion

 from homeassistant.components import conversation
 from homeassistant.core import Context, HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import area_registry as ar, device_registry as dr, intent

 from tests.common import MockConfigEntry
@@ -91,20 +93,24 @@ async def test_default_prompt(
         model=3,
         suggested_area="Test Area 2",
     )

-    with patch("google.generativeai.chat_async") as mock_chat:
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_model.return_value.start_chat.return_value = AsyncMock()
         result = await conversation.async_converse(
             hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
         )

     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
-    assert mock_chat.mock_calls[0][2] == snapshot
+    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot


 async def test_error_handling(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
 ) -> None:
     """Test that the default prompt works."""
-    with patch("google.generativeai.chat_async", side_effect=ClientError("")):
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        mock_chat.send_message_async.side_effect = ClientError("")
         result = await conversation.async_converse(
             hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
         )
@@ -125,7 +131,7 @@ async def test_template_error(
     )
     with patch(
         "google.generativeai.get_model",
-    ), patch("google.generativeai.chat_async"):
+    ), patch("google.generativeai.GenerativeModel"):
         await hass.config_entries.async_setup(mock_config_entry.entry_id)
         await hass.async_block_till_done()
         result = await conversation.async_converse(
@@ -146,3 +152,168 @@ async def test_conversation_agent(
         mock_config_entry.entry_id
     )
     assert agent.supported_languages == "*"
+
+
+async def test_generate_content_service_without_images(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test generate content service."""
+    stubbed_generated_content = (
+        "I'm thrilled to welcome you all to the release "
+        + "party for the latest version of Home Assistant!"
+    )
+
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_response = MagicMock()
+        mock_response.text = stubbed_generated_content
+        mock_model.return_value.generate_content_async = AsyncMock(
+            return_value=mock_response
+        )
+        response = await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {"prompt": "Write an opening speech for a Home Assistant release party"},
+            blocking=True,
+            return_response=True,
+        )
+
+    assert response == {
+        "text": stubbed_generated_content,
+    }
+    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+
+
+async def test_generate_content_service_with_image(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test generate content service."""
+    stubbed_generated_content = (
+        "A mail carrier is at your front door delivering a package"
+    )
+
+    with patch("google.generativeai.GenerativeModel") as mock_model, patch(
+        "homeassistant.components.google_generative_ai_conversation.Path.read_bytes",
+        return_value=b"image bytes",
+    ), patch("pathlib.Path.exists", return_value=True), patch.object(
+        hass.config, "is_allowed_path", return_value=True
+    ):
+        mock_response = MagicMock()
+        mock_response.text = stubbed_generated_content
+        mock_model.return_value.generate_content_async = AsyncMock(
+            return_value=mock_response
+        )
+        response = await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {
+                "prompt": "Describe this image from my doorbell camera",
+                "image_filename": "doorbell_snapshot.jpg",
+            },
+            blocking=True,
+            return_response=True,
+        )
+
+    assert response == {
+        "text": stubbed_generated_content,
+    }
+    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+
+
+@pytest.mark.usefixtures("mock_init_component")
+async def test_generate_content_service_error(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+) -> None:
+    """Test generate content service handles errors."""
+    with patch("google.generativeai.GenerativeModel") as mock_model, pytest.raises(
+        HomeAssistantError, match="Error generating content: None reason"
+    ):
+        mock_model.return_value.generate_content_async = AsyncMock(
+            side_effect=ClientError("reason")
+        )
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {"prompt": "write a story about an epic fail"},
+            blocking=True,
+            return_response=True,
+        )
+
+
+async def test_generate_content_service_with_image_not_allowed_path(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test generate content service with an image in a not allowed path."""
+    with patch("pathlib.Path.exists", return_value=True), patch.object(
+        hass.config, "is_allowed_path", return_value=False
+    ), pytest.raises(
+        HomeAssistantError,
+        match="Cannot read `doorbell_snapshot.jpg`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`",
+    ):
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {
+                "prompt": "Describe this image from my doorbell camera",
+                "image_filename": "doorbell_snapshot.jpg",
+            },
+            blocking=True,
+            return_response=True,
+        )
+
+
+async def test_generate_content_service_with_image_not_exists(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test generate content service with an image that does not exist."""
+    with patch("pathlib.Path.exists", return_value=True), patch.object(
+        hass.config, "is_allowed_path", return_value=True
+    ), patch("pathlib.Path.exists", return_value=False), pytest.raises(
+        HomeAssistantError, match="`doorbell_snapshot.jpg` does not exist"
+    ):
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {
+                "prompt": "Describe this image from my doorbell camera",
+                "image_filename": "doorbell_snapshot.jpg",
+            },
+            blocking=True,
+            return_response=True,
+        )
+
+
+async def test_generate_content_service_with_non_image(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test generate content service with a non image."""
+    with patch("pathlib.Path.exists", return_value=True), patch.object(
+        hass.config, "is_allowed_path", return_value=True
+    ), patch("pathlib.Path.exists", return_value=True), pytest.raises(
+        HomeAssistantError, match="`doorbell_snapshot.mp4` is not an image"
+    ):
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {
+                "prompt": "Describe this image from my doorbell camera",
+                "image_filename": "doorbell_snapshot.mp4",
+            },
+            blocking=True,
+            return_response=True,
+        )