LLM Tools support for Google Generative AI integration (#117644)

* initial commit

* Undo prompt chenges

* Move format_tool out of the class

* Only catch HomeAssistantError and vol.Invalid

* Add config flow option

* Fix type

* Add translation

* Allow changing API access from options flow

* Allow model picking

* Remove allowing HASS Access in main flow

* Move model to the top in options flow

* Make prompt conditional based on API access

* convert only once to dict

* Reduce debug logging

* Update title

* re-order models

* Address comments

* Move things

* Update labels

* Add tool call tests

* coverage

* Use LLM APIs

* Fixes

* Address comments

* Reinstate the title to not break entity name

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Denis Shulyaka 2024-05-20 05:11:25 +03:00 committed by GitHub
parent ac3321cef1
commit c3196a5667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 588 additions and 119 deletions

View File

@ -23,7 +23,7 @@ from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER
from .const import CONF_PROMPT, DOMAIN, LOGGER
SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
@ -97,11 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
genai.configure(api_key=entry.data[CONF_API_KEY])
try:
await hass.async_add_executor_job(
partial(
genai.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
)
)
await hass.async_add_executor_job(partial(genai.list_models))
except ClientError as err:
if err.reason == "API_KEY_INVALID":
LOGGER.error("Invalid API key: %s", err)

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from functools import partial
import logging
import types
from types import MappingProxyType
from typing import Any
@ -18,11 +17,15 @@ from homeassistant.config_entries import (
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_API_KEY
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
TemplateSelector,
)
@ -50,17 +53,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
}
)
DEFAULT_OPTIONS = types.MappingProxyType(
{
CONF_PROMPT: DEFAULT_PROMPT,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
}
)
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@ -99,7 +91,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "unknown"
else:
return self.async_create_entry(
title="Google Generative AI Conversation", data=user_input
title="Google Generative AI",
data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
)
return self.async_show_form(
@ -126,53 +120,96 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
if user_input is not None:
return self.async_create_entry(
title="Google Generative AI Conversation", data=user_input
)
schema = google_generative_ai_config_option_schema(self.config_entry.options)
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
schema = await google_generative_ai_config_option_schema(
self.hass, self.config_entry.options
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)
def google_generative_ai_config_option_schema(
async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
options: MappingProxyType[str, Any],
) -> dict:
"""Return a schema for Google Generative AI completion options."""
if not options:
options = DEFAULT_OPTIONS
api_models = await hass.async_add_executor_job(partial(genai.list_models))
models: list[SelectOptionDict] = [
SelectOptionDict(
label="Gemini 1.5 Flash (recommended)",
value="models/gemini-1.5-flash-latest",
),
]
models.extend(
SelectOptionDict(
label=api_model.display_name,
value=api_model.name,
)
for api_model in sorted(api_models, key=lambda x: x.display_name)
if (
api_model.name
not in (
"models/gemini-1.0-pro", # duplicate of gemini-pro
"models/gemini-1.5-flash-latest",
)
and "vision" not in api_model.name
and "generateContent" in api_model.supported_generation_methods
)
)
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
return {
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=DEFAULT_CHAT_MODEL,
): SelectSelector(SelectSelectorConfig(options=models)),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]},
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]},
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]},
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_K,
description={"suggested_value": options[CONF_TOP_K]},
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
): int,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]},
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): int,
}

View File

@ -21,11 +21,8 @@ An overview of the areas and the devices in this smart home:
{%- endif %}
{%- endfor %}
{%- endfor %}
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
"""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "models/gemini-pro"
CONF_TEMPERATURE = "temperature"
@ -36,3 +33,4 @@ CONF_TOP_K = "top_k"
DEFAULT_TOP_K = 1
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
DEFAULT_ALLOW_HASS_ACCESS = False

View File

@ -2,18 +2,21 @@
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError
import google.generativeai as genai
import google.generativeai.types as genai_types
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
@ -30,9 +33,13 @@ from .const import (
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry(
hass: HomeAssistant,
@ -44,6 +51,55 @@ async def async_setup_entry(
async_add_entities([agent])
SUPPORTED_SCHEMA_KEYS = {
"type",
"format",
"description",
"nullable",
"enum",
"items",
"properties",
"required",
}
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Format the schema to protobuf."""
result = {}
for key, val in schema.items():
if key not in SUPPORTED_SCHEMA_KEYS:
continue
if key == "type":
key = "type_"
val = val.upper()
elif key == "format":
key = "format_"
elif key == "items":
val = _format_schema(val)
elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
return result
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
"""Format tool specification."""
parameters = _format_schema(convert(tool.parameters))
return glm.Tool(
{
"function_declarations": [
{
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
]
}
)
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -80,6 +136,26 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
@ -93,8 +169,8 @@ class GoogleGenerativeAIConversationEntity(
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
),
},
tools=tools or None,
)
LOGGER.debug("Model: %s", model)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
@ -103,9 +179,8 @@ class GoogleGenerativeAIConversationEntity(
conversation_id = ulid.ulid_now()
messages = [{}, {}]
intent_response = intent.IntentResponse(language=user_input.language)
try:
prompt = self._async_generate_prompt(raw_prompt)
prompt = self._async_generate_prompt(raw_prompt, llm_api)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
@ -122,40 +197,84 @@ class GoogleGenerativeAIConversationEntity(
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
chat = model.start_chat(history=messages)
try:
chat_response = await chat.send_message_async(user_input.text)
except (
ClientError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
chat_request = user_input.text
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
chat_response = await chat.send_message_async(chat_request)
except (
ClientError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
tool_call = chat_response.parts[0].function_call
if not tool_call or not llm_api:
break
tool_input = llm.ToolInput(
tool_name=tool_call.name,
tool_args=dict(tool_call.args),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
function_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
function_response = {"error": type(e).__name__}
if str(e):
function_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", function_response)
chat_request = glm.Content(
parts=[
glm.Part(
function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response
)
)
]
)
LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
intent_response.async_set_speech(chat_response.text)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str) -> str:
def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
"""Generate a prompt for the user."""
raw_prompt += "\n"
if llm_api:
raw_prompt += llm_api.prompt_template
else:
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,

View File

@ -1,12 +1,12 @@
{
"domain": "google_generative_ai_conversation",
"name": "Google Generative AI Conversation",
"after_dependencies": ["assist_pipeline"],
"name": "Google Generative AI",
"after_dependencies": ["assist_pipeline", "intent"],
"codeowners": ["@tronikos"],
"config_flow": true,
"dependencies": ["conversation"],
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["google-generativeai==0.5.4"]
"requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"]
}

View File

@ -3,7 +3,8 @@
"step": {
"user": {
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]"
"api_key": "[%key:common::config_flow::data::api_key%]",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
}
}
},
@ -18,11 +19,12 @@
"init": {
"data": {
"prompt": "Prompt Template",
"model": "[%key:common::generic::model%]",
"chat_model": "[%key:common::generic::model%]",
"temperature": "Temperature",
"top_p": "Top P",
"top_k": "Top K",
"max_tokens": "Maximum tokens to return in response"
"max_tokens": "Maximum tokens to return in response",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
}
}
}

View File

@ -113,6 +113,7 @@ CONF_ACCESS_TOKEN: Final = "access_token"
CONF_ADDRESS: Final = "address"
CONF_AFTER: Final = "after"
CONF_ALIAS: Final = "alias"
CONF_LLM_HASS_API = "llm_hass_api"
CONF_ALLOWLIST_EXTERNAL_URLS: Final = "allowlist_external_urls"
CONF_API_KEY: Final = "api_key"
CONF_API_TOKEN: Final = "api_token"

View File

@ -2271,7 +2271,7 @@
"integration_type": "service",
"config_flow": true,
"iot_class": "cloud_polling",
"name": "Google Generative AI Conversation"
"name": "Google Generative AI"
},
"google_mail": {
"integration_type": "service",

View File

@ -17,18 +17,17 @@ from homeassistant.util.json import JsonObjectType
from . import intent
from .singleton import singleton
LLM_API_ASSIST = "assist"
PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant."
@singleton("llm")
@callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
"""Get all the LLM APIs."""
return {
"assist": AssistAPI(
hass=hass,
id="assist",
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
),
LLM_API_ASSIST: AssistAPI(hass=hass),
}
@ -170,6 +169,15 @@ class AssistAPI(API):
INTENT_GET_TEMPERATURE,
}
def __init__(self, hass: HomeAssistant) -> None:
"""Init the class."""
super().__init__(
hass=hass,
id=LLM_API_ASSIST,
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
)
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of LLM tools."""

View File

@ -88,6 +88,7 @@
"access_token": "Access token",
"api_key": "API key",
"api_token": "API token",
"llm_hass_api": "Control Home Assistant",
"ssl": "Uses an SSL certificate",
"verify_ssl": "Verify SSL certificate",
"elevation": "Elevation",

View File

@ -2825,6 +2825,9 @@ voip-utils==0.1.0
# homeassistant.components.volkszaehler
volkszaehler==0.4.0
# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3
# homeassistant.components.volvooncall
volvooncall==0.10.3

View File

@ -2190,6 +2190,9 @@ vilfo-api-client==0.5.0
# homeassistant.components.voip
voip-utils==0.1.0
# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3
# homeassistant.components.volvooncall
volvooncall==0.10.3

View File

@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@ -25,6 +27,15 @@ def mock_config_entry(hass):
return entry
@pytest.fixture
def mock_config_entry_with_assist(hass, mock_config_entry):
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
return mock_config_entry
@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
"""Initialize integration."""

View File

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_default_prompt[None]
# name: test_default_prompt[False-None]
list([
tuple(
'',
@ -13,6 +13,7 @@
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}),
),
tuple(
@ -36,9 +37,7 @@
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
@ -59,7 +58,7 @@
),
])
# ---
# name: test_default_prompt[conversation.google_generative_ai_conversation]
# name: test_default_prompt[False-conversation.google_generative_ai_conversation]
list([
tuple(
'',
@ -73,6 +72,7 @@
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}),
),
tuple(
@ -96,9 +96,7 @@
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
@ -119,48 +117,118 @@
),
])
# ---
# name: test_generate_content_service_with_image
# name: test_default_prompt[True-None]
list([
tuple(
'',
tuple(
),
dict({
'model_name': 'gemini-pro-vision',
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}),
),
tuple(
'().generate_content_async',
'().start_chat',
tuple(
list([
'Describe this image from my doorbell camera',
),
dict({
'history': list([
dict({
'data': b'image bytes',
'mime_type': 'image/jpeg',
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
),
dict({
}),
),
])
# ---
# name: test_generate_content_service_without_images
# name: test_default_prompt[True-conversation.google_generative_ai_conversation]
list([
tuple(
'',
tuple(
),
dict({
'model_name': 'gemini-pro',
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}),
),
tuple(
'().generate_content_async',
'().start_chat',
tuple(
list([
'Write an opening speech for a Home Assistant release party',
),
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
),
dict({
}),

View File

@ -1,6 +1,6 @@
"""Test the Google Generative AI Conversation config flow."""
from unittest.mock import patch
from unittest.mock import Mock, patch
from google.api_core.exceptions import ClientError
from google.rpc.error_details_pb2 import ErrorInfo
@ -18,12 +18,35 @@ from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_TOP_P,
DOMAIN,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import llm
from tests.common import MockConfigEntry
@pytest.fixture
def mock_models():
"""Mock the model list API."""
model_15_flash = Mock(
display_name="Gemini 1.5 Flash",
supported_generation_methods=["generateContent"],
)
model_15_flash.name = "models/gemini-1.5-flash-latest"
model_10_pro = Mock(
display_name="Gemini 1.0 Pro",
supported_generation_methods=["generateContent"],
)
model_10_pro.name = "models/gemini-pro"
with patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
return_value=[model_10_pro],
):
yield
async def test_form(hass: HomeAssistant) -> None:
"""Test we get the form."""
# Pretend we already set up a config entry.
@ -60,11 +83,14 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["data"] == {
"api_key": "bla",
}
assert result2["options"] == {
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_options(
hass: HomeAssistant, mock_config_entry, mock_init_component
hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
) -> None:
"""Test the options form."""
options_flow = await hass.config_entries.options.async_init(
@ -85,6 +111,9 @@ async def test_options(
assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
assert (
CONF_LLM_HASS_API not in options["data"]
), "Options flow should not set this key"
@pytest.mark.parametrize(

View File

@ -5,10 +5,18 @@ from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
intent,
llm,
)
from tests.common import MockConfigEntry
@ -16,6 +24,7 @@ from tests.common import MockConfigEntry
@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@pytest.mark.parametrize("allow_hass_access", [False, True])
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@ -24,6 +33,7 @@ async def test_default_prompt(
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
agent_id: str | None,
allow_hass_access: bool,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
@ -34,6 +44,15 @@ async def test_default_prompt(
if agent_id is None:
agent_id = mock_config_entry.entry_id
if allow_hass_access:
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
@ -100,12 +119,20 @@ async def test_default_prompt(
model=3,
suggested_area="Test Area 2",
)
with patch("google.generativeai.GenerativeModel") as mock_model:
with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
return_value=[],
) as mock_get_tools,
):
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = ["Hi there!"]
mock_part = MagicMock()
mock_part.function_call = None
chat_response.parts = [mock_part]
chat_response.text = "Hi there!"
result = await conversation.async_converse(
hass,
@ -118,6 +145,171 @@ async def test_default_prompt(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == allow_hass_access
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test that the default prompt works."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
]
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": ["test_value"]}
def tool_call(hass, tool_input):
mock_part.function_call = False
chat_response.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
"parts": [
{
"function_response": {
"name": "test_tool",
"response": {
"result": "Test response",
},
},
},
],
"role": "",
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": ["test_value"]},
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
),
)
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_exception(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test that the default prompt works."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): vol.All(
vol.Coerce(int), vol.Range(0, 100)
)
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": 1}
def tool_call(hass, tool_input):
mock_part.function_call = False
chat_response.text = "Hi there!"
raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
"parts": [
{
"function_response": {
"name": "test_tool",
"response": {
"error": "HomeAssistantError",
"error_text": "Test tool exception",
},
},
},
],
"role": "",
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": 1},
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
),
)
async def test_error_handling(

View File

@ -18,12 +18,13 @@ async def test_get_api_no_existing(hass: HomeAssistant) -> None:
async def test_register_api(hass: HomeAssistant) -> None:
"""Test registering an llm api."""
api = llm.AssistAPI(
hass=hass,
id="test",
name="Test",
prompt_template="Test",
)
class MyAPI(llm.API):
def async_get_tools(self) -> list[llm.Tool]:
"""Return a list of tools."""
return []
api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api