LLM Tools support for Google Generative AI integration (#117644)

* initial commit

* Undo prompt changes

* Move format_tool out of the class

* Only catch HomeAssistantError and vol.Invalid

* Add config flow option

* Fix type

* Add translation

* Allow changing API access from options flow

* Allow model picking

* Remove allowing HASS Access in main flow

* Move model to the top in options flow

* Make prompt conditional based on API access

* convert only once to dict

* Reduce debug logging

* Update title

* re-order models

* Address comments

* Move things

* Update labels

* Add tool call tests

* coverage

* Use LLM APIs

* Fixes

* Address comments

* Reinstate the title to not break entity name

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Denis Shulyaka 2024-05-20 05:11:25 +03:00 committed by GitHub
parent ac3321cef1
commit c3196a5667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 588 additions and 119 deletions

View File

@ -23,7 +23,7 @@ from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER from .const import CONF_PROMPT, DOMAIN, LOGGER
SERVICE_GENERATE_CONTENT = "generate_content" SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename" CONF_IMAGE_FILENAME = "image_filename"
@ -97,11 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
genai.configure(api_key=entry.data[CONF_API_KEY]) genai.configure(api_key=entry.data[CONF_API_KEY])
try: try:
await hass.async_add_executor_job( await hass.async_add_executor_job(partial(genai.list_models))
partial(
genai.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
)
)
except ClientError as err: except ClientError as err:
if err.reason == "API_KEY_INVALID": if err.reason == "API_KEY_INVALID":
LOGGER.error("Invalid API key: %s", err) LOGGER.error("Invalid API key: %s", err)

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from functools import partial from functools import partial
import logging import logging
import types
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any
@ -18,11 +17,15 @@ from homeassistant.config_entries import (
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, OptionsFlow,
) )
from homeassistant.const import CONF_API_KEY from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
NumberSelector, NumberSelector,
NumberSelectorConfig, NumberSelectorConfig,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
TemplateSelector, TemplateSelector,
) )
@ -50,17 +53,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
DEFAULT_OPTIONS = types.MappingProxyType(
{
CONF_PROMPT: DEFAULT_PROMPT,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TOP_K: DEFAULT_TOP_K,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
}
)
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
@ -99,7 +91,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "unknown" errors["base"] = "unknown"
else: else:
return self.async_create_entry( return self.async_create_entry(
title="Google Generative AI Conversation", data=user_input title="Google Generative AI",
data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
) )
return self.async_show_form( return self.async_show_form(
@ -126,53 +120,96 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Manage the options.""" """Manage the options."""
if user_input is not None: if user_input is not None:
return self.async_create_entry( if user_input[CONF_LLM_HASS_API] == "none":
title="Google Generative AI Conversation", data=user_input user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
schema = await google_generative_ai_config_option_schema(
self.hass, self.config_entry.options
) )
schema = google_generative_ai_config_option_schema(self.config_entry.options)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init",
data_schema=vol.Schema(schema), data_schema=vol.Schema(schema),
) )
def google_generative_ai_config_option_schema( async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
options: MappingProxyType[str, Any], options: MappingProxyType[str, Any],
) -> dict: ) -> dict:
"""Return a schema for Google Generative AI completion options.""" """Return a schema for Google Generative AI completion options."""
if not options: api_models = await hass.async_add_executor_job(partial(genai.list_models))
options = DEFAULT_OPTIONS
models: list[SelectOptionDict] = [
SelectOptionDict(
label="Gemini 1.5 Flash (recommended)",
value="models/gemini-1.5-flash-latest",
),
]
models.extend(
SelectOptionDict(
label=api_model.display_name,
value=api_model.name,
)
for api_model in sorted(api_models, key=lambda x: x.display_name)
if (
api_model.name
not in (
"models/gemini-1.0-pro", # duplicate of gemini-pro
"models/gemini-1.5-flash-latest",
)
and "vision" not in api_model.name
and "generateContent" in api_model.supported_generation_methods
)
)
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
return { return {
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=DEFAULT_CHAT_MODEL,
): SelectSelector(SelectSelectorConfig(options=models)),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]}, description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT, default=DEFAULT_PROMPT,
): TemplateSelector(), ): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional( vol.Optional(
CONF_TEMPERATURE, CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]}, description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE, default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional( vol.Optional(
CONF_TOP_P, CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]}, description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P, default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional( vol.Optional(
CONF_TOP_K, CONF_TOP_K,
description={"suggested_value": options[CONF_TOP_K]}, description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K, default=DEFAULT_TOP_K,
): int, ): int,
vol.Optional( vol.Optional(
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]}, description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS, default=DEFAULT_MAX_TOKENS,
): int, ): int,
} }

View File

@ -21,11 +21,8 @@ An overview of the areas and the devices in this smart home:
{%- endif %} {%- endif %}
{%- endfor %} {%- endfor %}
{%- endfor %} {%- endfor %}
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
""" """
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "models/gemini-pro" DEFAULT_CHAT_MODEL = "models/gemini-pro"
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
@ -36,3 +33,4 @@ CONF_TOP_K = "top_k"
DEFAULT_TOP_K = 1 DEFAULT_TOP_K = 1
CONF_MAX_TOKENS = "max_tokens" CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150 DEFAULT_MAX_TOKENS = 150
DEFAULT_ALLOW_HASS_ACCESS = False

View File

@ -2,18 +2,21 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Any, Literal
import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.types as genai_types import google.generativeai.types as genai_types
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, template from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid from homeassistant.util import ulid
@ -30,9 +33,13 @@ from .const import (
DEFAULT_TEMPERATURE, DEFAULT_TEMPERATURE,
DEFAULT_TOP_K, DEFAULT_TOP_K,
DEFAULT_TOP_P, DEFAULT_TOP_P,
DOMAIN,
LOGGER, LOGGER,
) )
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
@ -44,6 +51,55 @@ async def async_setup_entry(
async_add_entities([agent]) async_add_entities([agent])
SUPPORTED_SCHEMA_KEYS = {
"type",
"format",
"description",
"nullable",
"enum",
"items",
"properties",
"required",
}
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Format the schema to protobuf."""
result = {}
for key, val in schema.items():
if key not in SUPPORTED_SCHEMA_KEYS:
continue
if key == "type":
key = "type_"
val = val.upper()
elif key == "format":
key = "format_"
elif key == "items":
val = _format_schema(val)
elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
return result
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
"""Format tool specification."""
parameters = _format_schema(convert(tool.parameters))
return glm.Tool(
{
"function_declarations": [
{
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
]
}
)
class GoogleGenerativeAIConversationEntity( class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent conversation.ConversationEntity, conversation.AbstractConversationAgent
): ):
@ -80,6 +136,26 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL), model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
@ -93,8 +169,8 @@ class GoogleGenerativeAIConversationEntity(
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
), ),
}, },
tools=tools or None,
) )
LOGGER.debug("Model: %s", model)
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
@ -103,9 +179,8 @@ class GoogleGenerativeAIConversationEntity(
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
messages = [{}, {}] messages = [{}, {}]
intent_response = intent.IntentResponse(language=user_input.language)
try: try:
prompt = self._async_generate_prompt(raw_prompt) prompt = self._async_generate_prompt(raw_prompt, llm_api)
except TemplateError as err: except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err) LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error( intent_response.async_set_error(
@ -122,8 +197,11 @@ class GoogleGenerativeAIConversationEntity(
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
chat = model.start_chat(history=messages) chat = model.start_chat(history=messages)
chat_request = user_input.text
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try: try:
chat_response = await chat.send_message_async(user_input.text) chat_response = await chat.send_message_async(chat_request)
except ( except (
ClientError, ClientError,
ValueError, ValueError,
@ -149,13 +227,54 @@ class GoogleGenerativeAIConversationEntity(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
self.history[conversation_id] = chat.history self.history[conversation_id] = chat.history
tool_call = chat_response.parts[0].function_call
if not tool_call or not llm_api:
break
tool_input = llm.ToolInput(
tool_name=tool_call.name,
tool_args=dict(tool_call.args),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
function_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
function_response = {"error": type(e).__name__}
if str(e):
function_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", function_response)
chat_request = glm.Content(
parts=[
glm.Part(
function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response
)
)
]
)
intent_response.async_set_speech(chat_response.text) intent_response.async_set_speech(chat_response.text)
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
def _async_generate_prompt(self, raw_prompt: str) -> str: def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
"""Generate a prompt for the user.""" """Generate a prompt for the user."""
raw_prompt += "\n"
if llm_api:
raw_prompt += llm_api.prompt_template
else:
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
return template.Template(raw_prompt, self.hass).async_render( return template.Template(raw_prompt, self.hass).async_render(
{ {
"ha_name": self.hass.config.location_name, "ha_name": self.hass.config.location_name,

View File

@ -1,12 +1,12 @@
{ {
"domain": "google_generative_ai_conversation", "domain": "google_generative_ai_conversation",
"name": "Google Generative AI Conversation", "name": "Google Generative AI",
"after_dependencies": ["assist_pipeline"], "after_dependencies": ["assist_pipeline", "intent"],
"codeowners": ["@tronikos"], "codeowners": ["@tronikos"],
"config_flow": true, "config_flow": true,
"dependencies": ["conversation"], "dependencies": ["conversation"],
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"requirements": ["google-generativeai==0.5.4"] "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"]
} }

View File

@ -3,7 +3,8 @@
"step": { "step": {
"user": { "user": {
"data": { "data": {
"api_key": "[%key:common::config_flow::data::api_key%]" "api_key": "[%key:common::config_flow::data::api_key%]",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
} }
} }
}, },
@ -18,11 +19,12 @@
"init": { "init": {
"data": { "data": {
"prompt": "Prompt Template", "prompt": "Prompt Template",
"model": "[%key:common::generic::model%]", "chat_model": "[%key:common::generic::model%]",
"temperature": "Temperature", "temperature": "Temperature",
"top_p": "Top P", "top_p": "Top P",
"top_k": "Top K", "top_k": "Top K",
"max_tokens": "Maximum tokens to return in response" "max_tokens": "Maximum tokens to return in response",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
} }
} }
} }

View File

@ -113,6 +113,7 @@ CONF_ACCESS_TOKEN: Final = "access_token"
CONF_ADDRESS: Final = "address" CONF_ADDRESS: Final = "address"
CONF_AFTER: Final = "after" CONF_AFTER: Final = "after"
CONF_ALIAS: Final = "alias" CONF_ALIAS: Final = "alias"
CONF_LLM_HASS_API = "llm_hass_api"
CONF_ALLOWLIST_EXTERNAL_URLS: Final = "allowlist_external_urls" CONF_ALLOWLIST_EXTERNAL_URLS: Final = "allowlist_external_urls"
CONF_API_KEY: Final = "api_key" CONF_API_KEY: Final = "api_key"
CONF_API_TOKEN: Final = "api_token" CONF_API_TOKEN: Final = "api_token"

View File

@ -2271,7 +2271,7 @@
"integration_type": "service", "integration_type": "service",
"config_flow": true, "config_flow": true,
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"name": "Google Generative AI Conversation" "name": "Google Generative AI"
}, },
"google_mail": { "google_mail": {
"integration_type": "service", "integration_type": "service",

View File

@ -17,18 +17,17 @@ from homeassistant.util.json import JsonObjectType
from . import intent from . import intent
from .singleton import singleton from .singleton import singleton
LLM_API_ASSIST = "assist"
PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant."
@singleton("llm") @singleton("llm")
@callback @callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]: def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
"""Get all the LLM APIs.""" """Get all the LLM APIs."""
return { return {
"assist": AssistAPI( LLM_API_ASSIST: AssistAPI(hass=hass),
hass=hass,
id="assist",
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
),
} }
@ -170,6 +169,15 @@ class AssistAPI(API):
INTENT_GET_TEMPERATURE, INTENT_GET_TEMPERATURE,
} }
def __init__(self, hass: HomeAssistant) -> None:
"""Init the class."""
super().__init__(
hass=hass,
id=LLM_API_ASSIST,
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
)
@callback @callback
def async_get_tools(self) -> list[Tool]: def async_get_tools(self) -> list[Tool]:
"""Return a list of LLM tools.""" """Return a list of LLM tools."""

View File

@ -88,6 +88,7 @@
"access_token": "Access token", "access_token": "Access token",
"api_key": "API key", "api_key": "API key",
"api_token": "API token", "api_token": "API token",
"llm_hass_api": "Control Home Assistant",
"ssl": "Uses an SSL certificate", "ssl": "Uses an SSL certificate",
"verify_ssl": "Verify SSL certificate", "verify_ssl": "Verify SSL certificate",
"elevation": "Elevation", "elevation": "Elevation",

View File

@ -2825,6 +2825,9 @@ voip-utils==0.1.0
# homeassistant.components.volkszaehler # homeassistant.components.volkszaehler
volkszaehler==0.4.0 volkszaehler==0.4.0
# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3
# homeassistant.components.volvooncall # homeassistant.components.volvooncall
volvooncall==0.10.3 volvooncall==0.10.3

View File

@ -2190,6 +2190,9 @@ vilfo-api-client==0.5.0
# homeassistant.components.voip # homeassistant.components.voip
voip-utils==0.1.0 voip-utils==0.1.0
# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3
# homeassistant.components.volvooncall # homeassistant.components.volvooncall
volvooncall==0.10.3 volvooncall==0.10.3

View File

@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -25,6 +27,15 @@ def mock_config_entry(hass):
return entry return entry
@pytest.fixture
def mock_config_entry_with_assist(hass, mock_config_entry):
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
return mock_config_entry
@pytest.fixture @pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry): async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
"""Initialize integration.""" """Initialize integration."""

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_default_prompt[None] # name: test_default_prompt[False-None]
list([ list([
tuple( tuple(
'', '',
@ -13,6 +13,7 @@
'top_p': 1.0, 'top_p': 1.0,
}), }),
'model_name': 'models/gemini-pro', 'model_name': 'models/gemini-pro',
'tools': None,
}), }),
), ),
tuple( tuple(
@ -36,9 +37,7 @@
- Test Device 4 - Test Device 4
- 1 (3) - 1 (3)
Answer the user's questions about the world truthfully. If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -59,7 +58,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[conversation.google_generative_ai_conversation] # name: test_default_prompt[False-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',
@ -73,6 +72,7 @@
'top_p': 1.0, 'top_p': 1.0,
}), }),
'model_name': 'models/gemini-pro', 'model_name': 'models/gemini-pro',
'tools': None,
}), }),
), ),
tuple( tuple(
@ -96,9 +96,7 @@
- Test Device 4 - Test Device 4
- 1 (3) - 1 (3)
Answer the user's questions about the world truthfully. If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -119,48 +117,118 @@
), ),
]) ])
# --- # ---
# name: test_generate_content_service_with_image # name: test_default_prompt[True-None]
list([ list([
tuple( tuple(
'', '',
tuple( tuple(
), ),
dict({ dict({
'model_name': 'gemini-pro-vision', 'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}), }),
), ),
tuple( tuple(
'().generate_content_async', '().start_chat',
tuple( tuple(
list([ ),
'Describe this image from my doorbell camera',
dict({ dict({
'data': b'image bytes', 'history': list([
'mime_type': 'image/jpeg', dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}), }),
]), ]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
), ),
dict({ dict({
}), }),
), ),
]) ])
# --- # ---
# name: test_generate_content_service_without_images # name: test_default_prompt[True-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',
tuple( tuple(
), ),
dict({ dict({
'model_name': 'gemini-pro', 'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
'tools': None,
}), }),
), ),
tuple( tuple(
'().generate_content_async', '().start_chat',
tuple( tuple(
list([ ),
'Write an opening speech for a Home Assistant release party', dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]), ]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
), ),
dict({ dict({
}), }),

View File

@ -1,6 +1,6 @@
"""Test the Google Generative AI Conversation config flow.""" """Test the Google Generative AI Conversation config flow."""
from unittest.mock import patch from unittest.mock import Mock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError
from google.rpc.error_details_pb2 import ErrorInfo from google.rpc.error_details_pb2 import ErrorInfo
@ -18,12 +18,35 @@ from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_TOP_P, DEFAULT_TOP_P,
DOMAIN, DOMAIN,
) )
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import llm
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture
def mock_models():
"""Mock the model list API."""
model_15_flash = Mock(
display_name="Gemini 1.5 Flash",
supported_generation_methods=["generateContent"],
)
model_15_flash.name = "models/gemini-1.5-flash-latest"
model_10_pro = Mock(
display_name="Gemini 1.0 Pro",
supported_generation_methods=["generateContent"],
)
model_10_pro.name = "models/gemini-pro"
with patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
return_value=[model_10_pro],
):
yield
async def test_form(hass: HomeAssistant) -> None: async def test_form(hass: HomeAssistant) -> None:
"""Test we get the form.""" """Test we get the form."""
# Pretend we already set up a config entry. # Pretend we already set up a config entry.
@ -60,11 +83,14 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["data"] == { assert result2["data"] == {
"api_key": "bla", "api_key": "bla",
} }
assert result2["options"] == {
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
}
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_options( async def test_options(
hass: HomeAssistant, mock_config_entry, mock_init_component hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
) -> None: ) -> None:
"""Test the options form.""" """Test the options form."""
options_flow = await hass.config_entries.options.async_init( options_flow = await hass.config_entries.options.async_init(
@ -85,6 +111,9 @@ async def test_options(
assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
assert (
CONF_LLM_HASS_API not in options["data"]
), "Options flow should not set this key"
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -5,10 +5,18 @@ from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
intent,
llm,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -16,6 +24,7 @@ from tests.common import MockConfigEntry
@pytest.mark.parametrize( @pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"] "agent_id", [None, "conversation.google_generative_ai_conversation"]
) )
@pytest.mark.parametrize("allow_hass_access", [False, True])
async def test_default_prompt( async def test_default_prompt(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -24,6 +33,7 @@ async def test_default_prompt(
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
agent_id: str | None, agent_id: str | None,
allow_hass_access: bool,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
@ -34,6 +44,15 @@ async def test_default_prompt(
if agent_id is None: if agent_id is None:
agent_id = mock_config_entry.entry_id agent_id = mock_config_entry.entry_id
if allow_hass_access:
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "1234")}, connections={("test", "1234")},
@ -100,12 +119,20 @@ async def test_default_prompt(
model=3, model=3,
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
with patch("google.generativeai.GenerativeModel") as mock_model: with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
return_value=[],
) as mock_get_tools,
):
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock() chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
chat_response.parts = ["Hi there!"] mock_part = MagicMock()
mock_part.function_call = None
chat_response.parts = [mock_part]
chat_response.text = "Hi there!" chat_response.text = "Hi there!"
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
@ -118,6 +145,171 @@ async def test_default_prompt(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == allow_hass_access
@patch(
    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_call(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that a function call requested by the model is executed.

    The mocked model first answers with a ``function_call`` part for
    ``test_tool``. The test then checks that the tool was awaited with the
    expected ``llm.ToolInput``, that the tool result was sent back to the chat
    as a ``function_response``, and that the final speech is the model's text.
    """
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    # A single fake tool exposed through the patched Assist API.
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {
            vol.Optional("param1", description="Test parameters"): [
                vol.All(str, vol.Lower)
            ]
        }
    )
    mock_get_tools.return_value = [mock_tool]

    with patch("google.generativeai.GenerativeModel") as mock_model:
        mock_chat = AsyncMock()
        mock_model.return_value.start_chat.return_value = mock_chat
        chat_response = MagicMock()
        mock_chat.send_message_async.return_value = chat_response

        # First model response: a request to call the tool.
        mock_part = MagicMock()
        mock_part.function_call.name = "test_tool"
        mock_part.function_call.args = {"param1": ["test_value"]}

        def tool_call(hass, tool_input):
            # Once the tool has run, flip the shared response object to a
            # plain text answer so the follow-up reply carries no function
            # call (presumably what ends the tool-calling loop).
            mock_part.function_call = False
            chat_response.text = "Hi there!"
            return {"result": "Test response"}

        mock_tool.async_call.side_effect = tool_call
        chat_response.parts = [mock_part]

        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"

    # The second message sent to the chat is the tool result; take its first
    # positional argument and convert it with its own class's ``to_dict``.
    sent_content = mock_chat.send_message_async.mock_calls[1].args[0]
    sent_content_dict = type(sent_content).to_dict(sent_content)
    assert sent_content_dict == {
        "parts": [
            {
                "function_response": {
                    "name": "test_tool",
                    "response": {
                        "result": "Test response",
                    },
                },
            },
        ],
        "role": "",
    }

    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": ["test_value"]},
            platform="google_generative_ai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
        ),
    )
@patch(
    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_exception(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that an exception raised by a tool is reported to the model.

    The mocked model requests ``test_tool``, whose call raises
    ``HomeAssistantError``. The test checks that the error name and text are
    sent back to the chat as the ``function_response`` and that the
    conversation still completes with the model's final text.
    """
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    # A single fake tool exposed through the patched Assist API.
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {
            vol.Optional("param1", description="Test parameters"): vol.All(
                vol.Coerce(int), vol.Range(0, 100)
            )
        }
    )
    mock_get_tools.return_value = [mock_tool]

    with patch("google.generativeai.GenerativeModel") as mock_model:
        mock_chat = AsyncMock()
        mock_model.return_value.start_chat.return_value = mock_chat
        chat_response = MagicMock()
        mock_chat.send_message_async.return_value = chat_response

        # First model response: a request to call the tool.
        mock_part = MagicMock()
        mock_part.function_call.name = "test_tool"
        mock_part.function_call.args = {"param1": 1}

        def tool_call(hass, tool_input):
            # Prepare the shared response object for the follow-up reply (a
            # plain text answer without a function call), then fail the tool.
            mock_part.function_call = False
            chat_response.text = "Hi there!"
            raise HomeAssistantError("Test tool exception")

        mock_tool.async_call.side_effect = tool_call
        chat_response.parts = [mock_part]

        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"

    # The second message sent to the chat is the tool result; take its first
    # positional argument and convert it with its own class's ``to_dict``.
    sent_content = mock_chat.send_message_async.mock_calls[1].args[0]
    sent_content_dict = type(sent_content).to_dict(sent_content)
    assert sent_content_dict == {
        "parts": [
            {
                "function_response": {
                    "name": "test_tool",
                    "response": {
                        "error": "HomeAssistantError",
                        "error_text": "Test tool exception",
                    },
                },
            },
        ],
        "role": "",
    }

    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": 1},
            platform="google_generative_ai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
        ),
    )
async def test_error_handling( async def test_error_handling(

View File

@ -18,12 +18,13 @@ async def test_get_api_no_existing(hass: HomeAssistant) -> None:
async def test_register_api(hass: HomeAssistant) -> None: async def test_register_api(hass: HomeAssistant) -> None:
"""Test registering an llm api.""" """Test registering an llm api."""
api = llm.AssistAPI(
hass=hass, class MyAPI(llm.API):
id="test", def async_get_tools(self) -> list[llm.Tool]:
name="Test", """Return a list of tools."""
prompt_template="Test", return []
)
api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
llm.async_register_api(hass, api) llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api assert llm.async_get_api(hass, "test") is api