Mirror of https://github.com/home-assistant/core.git (synced 2025-07-23 13:17:32 +00:00)
LLM Tools support for Google Generative AI integration (#117644)
* initial commit
* Undo prompt changes
* Move format_tool out of the class
* Only catch HomeAssistantError and vol.Invalid
* Add config flow option
* Fix type
* Add translation
* Allow changing API access from options flow
* Allow model picking
* Remove allowing HASS Access in main flow
* Move model to the top in options flow
* Make prompt conditional based on API access
* convert only once to dict
* Reduce debug logging
* Update title
* re-order models
* Address comments
* Move things
* Update labels
* Add tool call tests
* coverage
* Use LLM APIs
* Fixes
* Address comments
* Reinstate the title to not break entity name

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Parent: ac3321cef1
Commit: c3196a5667
homeassistant/components/google_generative_ai_conversation/__init__.py

@@ -23,7 +23,7 @@ from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType

from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER
from .const import CONF_PROMPT, DOMAIN, LOGGER

SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"

@@ -97,11 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    genai.configure(api_key=entry.data[CONF_API_KEY])

    try:
        await hass.async_add_executor_job(
            partial(
                genai.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
            )
        )
        await hass.async_add_executor_job(partial(genai.list_models))
    except ClientError as err:
        if err.reason == "API_KEY_INVALID":
            LOGGER.error("Invalid API key: %s", err)
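The setup check above only needs to prove that the key works, so listing models replaces fetching one specific model. A standalone sketch of the same validation pattern, assuming the google-generativeai 0.5.x client and a hypothetical key:

import google.generativeai as genai
from google.api_core.exceptions import ClientError

def validate_api_key(api_key: str) -> bool:
    """Return True when the key can list models, mirroring async_setup_entry."""
    genai.configure(api_key=api_key)
    try:
        # list_models() exercises authentication without depending on any
        # particular model being configured, which is why the integration
        # switched away from genai.get_model.
        list(genai.list_models())
    except ClientError as err:
        if err.reason == "API_KEY_INVALID":
            return False
        raise
    return True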
homeassistant/components/google_generative_ai_conversation/config_flow.py

@@ -4,7 +4,6 @@ from __future__ import annotations

from functools import partial
import logging
import types
from types import MappingProxyType
from typing import Any

@@ -18,11 +17,15 @@ from homeassistant.config_entries import (
    ConfigFlowResult,
    OptionsFlow,
)
from homeassistant.const import CONF_API_KEY
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
    NumberSelector,
    NumberSelectorConfig,
    SelectOptionDict,
    SelectSelector,
    SelectSelectorConfig,
    TemplateSelector,
)

@@ -50,17 +53,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
    }
)

DEFAULT_OPTIONS = types.MappingProxyType(
    {
        CONF_PROMPT: DEFAULT_PROMPT,
        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
        CONF_TOP_P: DEFAULT_TOP_P,
        CONF_TOP_K: DEFAULT_TOP_K,
        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
    }
)


async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
    """Validate the user input allows us to connect.

@@ -99,7 +91,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
                errors["base"] = "unknown"
            else:
                return self.async_create_entry(
                    title="Google Generative AI Conversation", data=user_input
                    title="Google Generative AI",
                    data=user_input,
                    options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
                )

        return self.async_show_form(

@@ -126,53 +120,96 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
    ) -> ConfigFlowResult:
        """Manage the options."""
        if user_input is not None:
            return self.async_create_entry(
                title="Google Generative AI Conversation", data=user_input
            )
        schema = google_generative_ai_config_option_schema(self.config_entry.options)
            if user_input[CONF_LLM_HASS_API] == "none":
                user_input.pop(CONF_LLM_HASS_API)
            return self.async_create_entry(title="", data=user_input)
        schema = await google_generative_ai_config_option_schema(
            self.hass, self.config_entry.options
        )
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(schema),
        )


def google_generative_ai_config_option_schema(
async def google_generative_ai_config_option_schema(
    hass: HomeAssistant,
    options: MappingProxyType[str, Any],
) -> dict:
    """Return a schema for Google Generative AI completion options."""
    if not options:
        options = DEFAULT_OPTIONS
    api_models = await hass.async_add_executor_job(partial(genai.list_models))

    models: list[SelectOptionDict] = [
        SelectOptionDict(
            label="Gemini 1.5 Flash (recommended)",
            value="models/gemini-1.5-flash-latest",
        ),
    ]
    models.extend(
        SelectOptionDict(
            label=api_model.display_name,
            value=api_model.name,
        )
        for api_model in sorted(api_models, key=lambda x: x.display_name)
        if (
            api_model.name
            not in (
                "models/gemini-1.0-pro",  # duplicate of gemini-pro
                "models/gemini-1.5-flash-latest",
            )
            and "vision" not in api_model.name
            and "generateContent" in api_model.supported_generation_methods
        )
    )

    apis: list[SelectOptionDict] = [
        SelectOptionDict(
            label="No control",
            value="none",
        )
    ]
    apis.extend(
        SelectOptionDict(
            label=api.name,
            value=api.id,
        )
        for api in llm.async_get_apis(hass)
    )

    return {
        vol.Optional(
            CONF_CHAT_MODEL,
            description={"suggested_value": options.get(CONF_CHAT_MODEL)},
            default=DEFAULT_CHAT_MODEL,
        ): SelectSelector(SelectSelectorConfig(options=models)),
        vol.Optional(
            CONF_LLM_HASS_API,
            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
            default="none",
        ): SelectSelector(SelectSelectorConfig(options=apis)),
        vol.Optional(
            CONF_PROMPT,
            description={"suggested_value": options[CONF_PROMPT]},
            description={"suggested_value": options.get(CONF_PROMPT)},
            default=DEFAULT_PROMPT,
        ): TemplateSelector(),
        vol.Optional(
            CONF_CHAT_MODEL,
            description={
                "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
            },
            default=DEFAULT_CHAT_MODEL,
        ): str,
        vol.Optional(
            CONF_TEMPERATURE,
            description={"suggested_value": options[CONF_TEMPERATURE]},
            description={"suggested_value": options.get(CONF_TEMPERATURE)},
            default=DEFAULT_TEMPERATURE,
        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
        vol.Optional(
            CONF_TOP_P,
            description={"suggested_value": options[CONF_TOP_P]},
            description={"suggested_value": options.get(CONF_TOP_P)},
            default=DEFAULT_TOP_P,
        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
        vol.Optional(
            CONF_TOP_K,
            description={"suggested_value": options[CONF_TOP_K]},
            description={"suggested_value": options.get(CONF_TOP_K)},
            default=DEFAULT_TOP_K,
        ): int,
        vol.Optional(
            CONF_MAX_TOKENS,
            description={"suggested_value": options[CONF_MAX_TOKENS]},
            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
            default=DEFAULT_MAX_TOKENS,
        ): int,
    }
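The options flow now builds its model dropdown from the live model list. A sketch of the same filtering rule applied to plain dicts (the model entries here are illustrative):

models = [
    {"name": "models/gemini-1.0-pro", "display_name": "Gemini 1.0 Pro", "methods": ["generateContent"]},
    {"name": "models/gemini-1.5-pro-latest", "display_name": "Gemini 1.5 Pro", "methods": ["generateContent"]},
    {"name": "models/gemini-pro-vision", "display_name": "Gemini Pro Vision", "methods": ["generateContent"]},
    {"name": "models/text-embedding-004", "display_name": "Embedding", "methods": ["embedContent"]},
]

# Keep generateContent-capable models, skip vision variants and the two
# names already covered by the pinned "recommended" entry.
selectable = [
    m["name"]
    for m in sorted(models, key=lambda m: m["display_name"])
    if m["name"] not in ("models/gemini-1.0-pro", "models/gemini-1.5-flash-latest")
    and "vision" not in m["name"]
    and "generateContent" in m["methods"]
]
assert selectable == ["models/gemini-1.5-pro-latest"]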
homeassistant/components/google_generative_ai_conversation/const.py

@@ -21,11 +21,8 @@ An overview of the areas and the devices in this smart home:
{%- endif %}
{%- endfor %}
{%- endfor %}

Answer the user's questions about the world truthfully.

If the user wants to control a device, reject the request and suggest using the Home Assistant app.
"""

CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "models/gemini-pro"
CONF_TEMPERATURE = "temperature"

@@ -36,3 +33,4 @@ CONF_TOP_K = "top_k"
DEFAULT_TOP_K = 1
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
DEFAULT_ALLOW_HASS_ACCESS = False
homeassistant/components/google_generative_ai_conversation/conversation.py

@@ -2,18 +2,21 @@

from __future__ import annotations

from typing import Literal
from typing import Any, Literal

import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError
import google.generativeai as genai
import google.generativeai.types as genai_types
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid

@@ -30,9 +33,13 @@ from .const import (
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DOMAIN,
    LOGGER,
)

# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10


async def async_setup_entry(
    hass: HomeAssistant,

@@ -44,6 +51,55 @@ async def async_setup_entry(
    async_add_entities([agent])


SUPPORTED_SCHEMA_KEYS = {
    "type",
    "format",
    "description",
    "nullable",
    "enum",
    "items",
    "properties",
    "required",
}


def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """Format the schema to protobuf."""
    result = {}
    for key, val in schema.items():
        if key not in SUPPORTED_SCHEMA_KEYS:
            continue
        if key == "type":
            key = "type_"
            val = val.upper()
        elif key == "format":
            key = "format_"
        elif key == "items":
            val = _format_schema(val)
        elif key == "properties":
            val = {k: _format_schema(v) for k, v in val.items()}
        result[key] = val
    return result
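_format_schema exists because the protobuf-backed glm.Schema renames the reserved field names "type" and "format" to "type_" and "format_" and expects enum-style type values, while convert() emits plain JSON-schema keys. A worked example on an illustrative input dict:

schema = {
    "type": "object",
    "properties": {"param1": {"type": "array", "items": {"type": "string"}}},
    "title": "dropped: not in SUPPORTED_SCHEMA_KEYS",
}
assert _format_schema(schema) == {
    "type_": "OBJECT",
    "properties": {"param1": {"type_": "ARRAY", "items": {"type_": "STRING"}}},
}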
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
    """Format tool specification."""

    parameters = _format_schema(convert(tool.parameters))

    return glm.Tool(
        {
            "function_declarations": [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": parameters,
                }
            ]
        }
    )


class GoogleGenerativeAIConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):

@@ -80,6 +136,26 @@ class GoogleGenerativeAIConversationEntity(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
        intent_response = intent.IntentResponse(language=user_input.language)
        llm_api: llm.API | None = None
        tools: list[dict[str, Any]] | None = None

        if self.entry.options.get(CONF_LLM_HASS_API):
            try:
                llm_api = llm.async_get_api(
                    self.hass, self.entry.options[CONF_LLM_HASS_API]
                )
            except HomeAssistantError as err:
                LOGGER.error("Error getting LLM API: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Error preparing LLM API: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=user_input.conversation_id
                )
            tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
        model = genai.GenerativeModel(
            model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),

@@ -93,8 +169,8 @@ class GoogleGenerativeAIConversationEntity(
                    CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
                ),
            },
            tools=tools or None,
        )
        LOGGER.debug("Model: %s", model)

        if user_input.conversation_id in self.history:
            conversation_id = user_input.conversation_id

@@ -103,9 +179,8 @@ class GoogleGenerativeAIConversationEntity(
            conversation_id = ulid.ulid_now()
            messages = [{}, {}]

        intent_response = intent.IntentResponse(language=user_input.language)
        try:
            prompt = self._async_generate_prompt(raw_prompt)
            prompt = self._async_generate_prompt(raw_prompt, llm_api)
        except TemplateError as err:
            LOGGER.error("Error rendering prompt: %s", err)
            intent_response.async_set_error(

@@ -122,40 +197,84 @@ class GoogleGenerativeAIConversationEntity(
        LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)

        chat = model.start_chat(history=messages)
        try:
            chat_response = await chat.send_message_async(user_input.text)
        except (
            ClientError,
            ValueError,
            genai_types.BlockedPromptException,
            genai_types.StopCandidateException,
        ) as err:
            LOGGER.error("Error sending message: %s", err)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, I had a problem talking to Google Generative AI: {err}",
        chat_request = user_input.text
        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                chat_response = await chat.send_message_async(chat_request)
            except (
                ClientError,
                ValueError,
                genai_types.BlockedPromptException,
                genai_types.StopCandidateException,
            ) as err:
                LOGGER.error("Error sending message: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem talking to Google Generative AI: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            LOGGER.debug("Response: %s", chat_response.parts)
            if not chat_response.parts:
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    "Sorry, I had a problem talking to Google Generative AI. Likely blocked",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )
            self.history[conversation_id] = chat.history
            tool_call = chat_response.parts[0].function_call

            if not tool_call or not llm_api:
                break

            tool_input = llm.ToolInput(
                tool_name=tool_call.name,
                tool_args=dict(tool_call.args),
                platform=DOMAIN,
                context=user_input.context,
                user_prompt=user_input.text,
                language=user_input.language,
                assistant=conversation.DOMAIN,
            )
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
            LOGGER.debug(
                "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
            )
            try:
                function_response = await llm_api.async_call_tool(tool_input)
            except (HomeAssistantError, vol.Invalid) as e:
                function_response = {"error": type(e).__name__}
                if str(e):
                    function_response["error_text"] = str(e)

            LOGGER.debug("Tool response: %s", function_response)
            chat_request = glm.Content(
                parts=[
                    glm.Part(
                        function_response=glm.FunctionResponse(
                            name=tool_call.name, response=function_response
                        )
                    )
                ]
            )

        LOGGER.debug("Response: %s", chat_response.parts)
        if not chat_response.parts:
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                "Sorry, I had a problem talking to Google Generative AI. Likely blocked",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )
        self.history[conversation_id] = chat.history
        intent_response.async_set_speech(chat_response.text)
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _async_generate_prompt(self, raw_prompt: str) -> str:
    def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
        """Generate a prompt for the user."""
        raw_prompt += "\n"
        if llm_api:
            raw_prompt += llm_api.prompt_template
        else:
            raw_prompt += llm.PROMPT_NO_API_CONFIGURED

        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
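Stripped of the Home Assistant plumbing, the new conversation flow alternates between model turns and tool turns until the model answers in plain text. A condensed sketch; chat, llm_api, and build_tool_input stand in for the real objects above:

import google.ai.generativelanguage as glm

MAX_TOOL_ITERATIONS = 10  # guard against a model that keeps requesting tools

def build_tool_input(tool_call):
    # Stand-in for constructing llm.ToolInput from the model's function_call.
    return {"tool_name": tool_call.name, "tool_args": dict(tool_call.args)}

async def converse(chat, llm_api, user_text):
    chat_request = user_text
    for _ in range(MAX_TOOL_ITERATIONS):
        chat_response = await chat.send_message_async(chat_request)
        tool_call = chat_response.parts[0].function_call
        if not tool_call or not llm_api:
            # The model produced a text answer; the loop is done.
            return chat_response.text
        # The model requested a tool: run it, then feed the result back as
        # a function_response part so the next turn can build on it.
        result = await llm_api.async_call_tool(build_tool_input(tool_call))
        chat_request = glm.Content(
            parts=[
                glm.Part(
                    function_response=glm.FunctionResponse(
                        name=tool_call.name, response=result
                    )
                )
            ]
        )
    return "Maximum number of tool iterations reached"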
homeassistant/components/google_generative_ai_conversation/manifest.json

@@ -1,12 +1,12 @@
{
  "domain": "google_generative_ai_conversation",
  "name": "Google Generative AI Conversation",
  "after_dependencies": ["assist_pipeline"],
  "name": "Google Generative AI",
  "after_dependencies": ["assist_pipeline", "intent"],
  "codeowners": ["@tronikos"],
  "config_flow": true,
  "dependencies": ["conversation"],
  "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
  "integration_type": "service",
  "iot_class": "cloud_polling",
  "requirements": ["google-generativeai==0.5.4"]
  "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"]
}
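The new voluptuous-openapi requirement supplies convert(), which turns a tool's voluptuous schema into the JSON-schema-style dict that _format_tool then runs through _format_schema. A minimal sketch; the exact output shape depends on the library version:

import voluptuous as vol
from voluptuous_openapi import convert

params = vol.Schema({vol.Optional("param1"): vol.All(str, vol.Lower)})
print(convert(params))
# e.g. {'type': 'object', 'properties': {'param1': {'type': 'string'}}}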
homeassistant/components/google_generative_ai_conversation/strings.json

@@ -3,7 +3,8 @@
    "step": {
      "user": {
        "data": {
          "api_key": "[%key:common::config_flow::data::api_key%]"
          "api_key": "[%key:common::config_flow::data::api_key%]",
          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
        }
      }
    },

@@ -18,11 +19,12 @@
      "init": {
        "data": {
          "prompt": "Prompt Template",
          "model": "[%key:common::generic::model%]",
          "chat_model": "[%key:common::generic::model%]",
          "temperature": "Temperature",
          "top_p": "Top P",
          "top_k": "Top K",
          "max_tokens": "Maximum tokens to return in response"
          "max_tokens": "Maximum tokens to return in response",
          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
        }
      }
    }
homeassistant/const.py

@@ -113,6 +113,7 @@ CONF_ACCESS_TOKEN: Final = "access_token"
CONF_ADDRESS: Final = "address"
CONF_AFTER: Final = "after"
CONF_ALIAS: Final = "alias"
CONF_LLM_HASS_API = "llm_hass_api"
CONF_ALLOWLIST_EXTERNAL_URLS: Final = "allowlist_external_urls"
CONF_API_KEY: Final = "api_key"
CONF_API_TOKEN: Final = "api_token"
homeassistant/generated/integrations.json

@@ -2271,7 +2271,7 @@
      "integration_type": "service",
      "config_flow": true,
      "iot_class": "cloud_polling",
      "name": "Google Generative AI Conversation"
      "name": "Google Generative AI"
    },
    "google_mail": {
      "integration_type": "service",
homeassistant/helpers/llm.py

@@ -17,18 +17,17 @@ from homeassistant.util.json import JsonObjectType
from . import intent
from .singleton import singleton

LLM_API_ASSIST = "assist"

PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant."


@singleton("llm")
@callback
def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
    """Get all the LLM APIs."""
    return {
        "assist": AssistAPI(
            hass=hass,
            id="assist",
            name="Assist",
            prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
        ),
        LLM_API_ASSIST: AssistAPI(hass=hass),
    }

@@ -170,6 +169,15 @@ class AssistAPI(API):
        INTENT_GET_TEMPERATURE,
    }

    def __init__(self, hass: HomeAssistant) -> None:
        """Init the class."""
        super().__init__(
            hass=hass,
            id=LLM_API_ASSIST,
            name="Assist",
            prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
        )

    @callback
    def async_get_tools(self) -> list[Tool]:
        """Return a list of LLM tools."""
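Anything beyond the built-in Assist API plugs in the same way: subclass llm.API, implement async_get_tools, and register the instance. A sketch mirroring the updated test at the end of this diff (the id and name are illustrative, and hass is a running HomeAssistant instance):

from homeassistant.helpers import llm

class MyAPI(llm.API):
    def async_get_tools(self) -> list[llm.Tool]:
        """Return a list of tools."""
        return []

api = MyAPI(hass=hass, id="my_api", name="My API", prompt_template="")
llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "my_api") is api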
homeassistant/strings.json

@@ -88,6 +88,7 @@
        "access_token": "Access token",
        "api_key": "API key",
        "api_token": "API token",
        "llm_hass_api": "Control Home Assistant",
        "ssl": "Uses an SSL certificate",
        "verify_ssl": "Verify SSL certificate",
        "elevation": "Elevation",
requirements_all.txt

@@ -2825,6 +2825,9 @@ voip-utils==0.1.0
# homeassistant.components.volkszaehler
volkszaehler==0.4.0

# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3

# homeassistant.components.volvooncall
volvooncall==0.10.3
requirements_test_all.txt

@@ -2190,6 +2190,9 @@ vilfo-api-client==0.5.0
# homeassistant.components.voip
voip-utils==0.1.0

# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3

# homeassistant.components.volvooncall
volvooncall==0.10.3
tests/components/google_generative_ai_conversation/conftest.py

@@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component

from tests.common import MockConfigEntry

@@ -25,6 +27,15 @@ def mock_config_entry(hass):
    return entry


@pytest.fixture
def mock_config_entry_with_assist(hass, mock_config_entry):
    """Mock a config entry with assist."""
    hass.config_entries.async_update_entry(
        mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
    )
    return mock_config_entry


@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry):
    """Initialize integration."""
tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr

@@ -1,5 +1,5 @@
# serializer version: 1
# name: test_default_prompt[None]
# name: test_default_prompt[False-None]
  list([
    tuple(
      '',

@@ -13,6 +13,7 @@
        'top_p': 1.0,
      }),
      'model_name': 'models/gemini-pro',
      'tools': None,
    }),
    ),
    tuple(

@@ -36,9 +37,7 @@
      - Test Device 4
      - 1 (3)

      Answer the user's questions about the world truthfully.

      If the user wants to control a device, reject the request and suggest using the Home Assistant app.
      If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
      ''',
      'role': 'user',
    }),

@@ -59,7 +58,7 @@
    ),
  ])
# ---
# name: test_default_prompt[conversation.google_generative_ai_conversation]
# name: test_default_prompt[False-conversation.google_generative_ai_conversation]
  list([
    tuple(
      '',

@@ -73,6 +72,7 @@
        'top_p': 1.0,
      }),
      'model_name': 'models/gemini-pro',
      'tools': None,
    }),
    ),
    tuple(

@@ -96,9 +96,7 @@
      - Test Device 4
      - 1 (3)

      Answer the user's questions about the world truthfully.

      If the user wants to control a device, reject the request and suggest using the Home Assistant app.
      If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
      ''',
      'role': 'user',
    }),

@@ -119,48 +117,118 @@
    ),
  ])
# ---
# name: test_generate_content_service_with_image
# name: test_default_prompt[True-None]
  list([
    tuple(
      '',
      tuple(
      ),
      dict({
        'model_name': 'gemini-pro-vision',
        'generation_config': dict({
          'max_output_tokens': 150,
          'temperature': 0.9,
          'top_k': 1,
          'top_p': 1.0,
        }),
        'model_name': 'models/gemini-pro',
        'tools': None,
      }),
    ),
    tuple(
      '().generate_content_async',
      '().start_chat',
      tuple(
        list([
          'Describe this image from my doorbell camera',
      ),
      dict({
        'history': list([
          dict({
            'data': b'image bytes',
            'mime_type': 'image/jpeg',
            'parts': '''
              This smart home is controlled by Home Assistant.

              An overview of the areas and the devices in this smart home:

              Test Area:
              - Test Device (Test Model)

              Test Area 2:
              - Test Device 2
              - Test Device 3 (Test Model 3A)
              - Test Device 4
              - 1 (3)

              Call the intent tools to control the system. Just pass the name to the intent.
            ''',
            'role': 'user',
          }),
          dict({
            'parts': 'Ok',
            'role': 'model',
          }),
        ]),
      }),
    ),
    tuple(
      '().start_chat().send_message_async',
      tuple(
        'hello',
      ),
      dict({
      }),
    ),
  ])
# ---
# name: test_generate_content_service_without_images
# name: test_default_prompt[True-conversation.google_generative_ai_conversation]
  list([
    tuple(
      '',
      tuple(
      ),
      dict({
        'model_name': 'gemini-pro',
        'generation_config': dict({
          'max_output_tokens': 150,
          'temperature': 0.9,
          'top_k': 1,
          'top_p': 1.0,
        }),
        'model_name': 'models/gemini-pro',
        'tools': None,
      }),
    ),
    tuple(
      '().generate_content_async',
      '().start_chat',
      tuple(
        list([
          'Write an opening speech for a Home Assistant release party',
      ),
      dict({
        'history': list([
          dict({
            'parts': '''
              This smart home is controlled by Home Assistant.

              An overview of the areas and the devices in this smart home:

              Test Area:
              - Test Device (Test Model)

              Test Area 2:
              - Test Device 2
              - Test Device 3 (Test Model 3A)
              - Test Device 4
              - 1 (3)

              Call the intent tools to control the system. Just pass the name to the intent.
            ''',
            'role': 'user',
          }),
          dict({
            'parts': 'Ok',
            'role': 'model',
          }),
        ]),
      }),
    ),
    tuple(
      '().start_chat().send_message_async',
      tuple(
        'hello',
      ),
      dict({
      }),
tests/components/google_generative_ai_conversation/test_config_flow.py

@@ -1,6 +1,6 @@
"""Test the Google Generative AI Conversation config flow."""

from unittest.mock import patch
from unittest.mock import Mock, patch

from google.api_core.exceptions import ClientError
from google.rpc.error_details_pb2 import ErrorInfo

@@ -18,12 +18,35 @@ from homeassistant.components.google_generative_ai_conversation.const import (
    DEFAULT_TOP_P,
    DOMAIN,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import llm

from tests.common import MockConfigEntry


@pytest.fixture
def mock_models():
    """Mock the model list API."""
    model_15_flash = Mock(
        display_name="Gemini 1.5 Flash",
        supported_generation_methods=["generateContent"],
    )
    model_15_flash.name = "models/gemini-1.5-flash-latest"

    model_10_pro = Mock(
        display_name="Gemini 1.0 Pro",
        supported_generation_methods=["generateContent"],
    )
    model_10_pro.name = "models/gemini-pro"
    with patch(
        "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
        return_value=[model_10_pro],
    ):
        yield


async def test_form(hass: HomeAssistant) -> None:
    """Test we get the form."""
    # Pretend we already set up a config entry.

@@ -60,11 +83,14 @@ async def test_form(hass: HomeAssistant) -> None:
    assert result2["data"] == {
        "api_key": "bla",
    }
    assert result2["options"] == {
        CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
    }
    assert len(mock_setup_entry.mock_calls) == 1


async def test_options(
    hass: HomeAssistant, mock_config_entry, mock_init_component
    hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
) -> None:
    """Test the options form."""
    options_flow = await hass.config_entries.options.async_init(

@@ -85,6 +111,9 @@ async def test_options(
    assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
    assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
    assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
    assert (
        CONF_LLM_HASS_API not in options["data"]
    ), "Options flow should not set this key"


@pytest.mark.parametrize(
tests/components/google_generative_ai_conversation/test_conversation.py

@@ -5,10 +5,18 @@ from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol

from homeassistant.components import conversation
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
    area_registry as ar,
    device_registry as dr,
    intent,
    llm,
)

from tests.common import MockConfigEntry

@@ -16,6 +24,7 @@ from tests.common import MockConfigEntry
@pytest.mark.parametrize(
    "agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@pytest.mark.parametrize("allow_hass_access", [False, True])
async def test_default_prompt(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,

@@ -24,6 +33,7 @@ async def test_default_prompt(
    device_registry: dr.DeviceRegistry,
    snapshot: SnapshotAssertion,
    agent_id: str | None,
    allow_hass_access: bool,
) -> None:
    """Test that the default prompt works."""
    entry = MockConfigEntry(title=None)

@@ -34,6 +44,15 @@ async def test_default_prompt(
    if agent_id is None:
        agent_id = mock_config_entry.entry_id

    if allow_hass_access:
        hass.config_entries.async_update_entry(
            mock_config_entry,
            options={
                **mock_config_entry.options,
                CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
            },
        )

    device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections={("test", "1234")},

@@ -100,12 +119,20 @@ async def test_default_prompt(
        model=3,
        suggested_area="Test Area 2",
    )
    with patch("google.generativeai.GenerativeModel") as mock_model:
    with (
        patch("google.generativeai.GenerativeModel") as mock_model,
        patch(
            "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
            return_value=[],
        ) as mock_get_tools,
    ):
        mock_chat = AsyncMock()
        mock_model.return_value.start_chat.return_value = mock_chat
        chat_response = MagicMock()
        mock_chat.send_message_async.return_value = chat_response
        chat_response.parts = ["Hi there!"]
        mock_part = MagicMock()
        mock_part.function_call = None
        chat_response.parts = [mock_part]
        chat_response.text = "Hi there!"
        result = await conversation.async_converse(
            hass,

@@ -118,6 +145,171 @@ async def test_default_prompt(
    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
    assert mock_get_tools.called == allow_hass_access


@patch(
    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_call(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that the default prompt works."""
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {
            vol.Optional("param1", description="Test parameters"): [
                vol.All(str, vol.Lower)
            ]
        }
    )

    mock_get_tools.return_value = [mock_tool]

    with patch("google.generativeai.GenerativeModel") as mock_model:
        mock_chat = AsyncMock()
        mock_model.return_value.start_chat.return_value = mock_chat
        chat_response = MagicMock()
        mock_chat.send_message_async.return_value = chat_response
        mock_part = MagicMock()
        mock_part.function_call.name = "test_tool"
        mock_part.function_call.args = {"param1": ["test_value"]}

        def tool_call(hass, tool_input):
            mock_part.function_call = False
            chat_response.text = "Hi there!"
            return {"result": "Test response"}

        mock_tool.async_call.side_effect = tool_call
        chat_response.parts = [mock_part]
        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
    assert mock_tool_call == {
        "parts": [
            {
                "function_response": {
                    "name": "test_tool",
                    "response": {
                        "result": "Test response",
                    },
                },
            },
        ],
        "role": "",
    }

    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": ["test_value"]},
            platform="google_generative_ai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
        ),
    )


@patch(
    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
async def test_function_exception(
    mock_get_tools,
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that the default prompt works."""
    agent_id = mock_config_entry_with_assist.entry_id
    context = Context()

    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {
            vol.Optional("param1", description="Test parameters"): vol.All(
                vol.Coerce(int), vol.Range(0, 100)
            )
        }
    )

    mock_get_tools.return_value = [mock_tool]

    with patch("google.generativeai.GenerativeModel") as mock_model:
        mock_chat = AsyncMock()
        mock_model.return_value.start_chat.return_value = mock_chat
        chat_response = MagicMock()
        mock_chat.send_message_async.return_value = chat_response
        mock_part = MagicMock()
        mock_part.function_call.name = "test_tool"
        mock_part.function_call.args = {"param1": 1}

        def tool_call(hass, tool_input):
            mock_part.function_call = False
            chat_response.text = "Hi there!"
            raise HomeAssistantError("Test tool exception")

        mock_tool.async_call.side_effect = tool_call
        chat_response.parts = [mock_part]
        result = await conversation.async_converse(
            hass,
            "Please call the test function",
            None,
            context,
            agent_id=agent_id,
        )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
    assert mock_tool_call == {
        "parts": [
            {
                "function_response": {
                    "name": "test_tool",
                    "response": {
                        "error": "HomeAssistantError",
                        "error_text": "Test tool exception",
                    },
                },
            },
        ],
        "role": "",
    }
    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            tool_name="test_tool",
            tool_args={"param1": 1},
            platform="google_generative_ai_conversation",
            context=context,
            user_prompt="Please call the test function",
            language="en",
            assistant="conversation",
        ),
    )


async def test_error_handling(
tests/helpers/test_llm.py

@@ -18,12 +18,13 @@ async def test_get_api_no_existing(hass: HomeAssistant) -> None:

async def test_register_api(hass: HomeAssistant) -> None:
    """Test registering an llm api."""
    api = llm.AssistAPI(
        hass=hass,
        id="test",
        name="Test",
        prompt_template="Test",
    )

    class MyAPI(llm.API):
        def async_get_tools(self) -> list[llm.Tool]:
            """Return a list of tools."""
            return []

    api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
    llm.async_register_api(hass, api)

    assert llm.async_get_api(hass, "test") is api