Increase AI Task default tokens for Google Gemini

This commit is contained in:
Paulus Schoutsen
2025-10-23 21:11:31 -04:00
parent 312812dd8b
commit 0dc74d3d50
3 changed files with 19 additions and 4 deletions

View File

@@ -15,7 +15,13 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .const import CONF_CHAT_MODEL, CONF_RECOMMENDED, LOGGER, RECOMMENDED_IMAGE_MODEL
from .const import (
CONF_CHAT_MODEL,
CONF_RECOMMENDED,
LOGGER,
RECOMMENDED_A_TASK_MAX_TOKENS,
RECOMMENDED_IMAGE_MODEL,
)
from .entity import (
ERROR_GETTING_RESPONSE,
GoogleGenerativeAILLMBaseEntity,
@@ -73,7 +79,9 @@ class GoogleGenerativeAITaskEntity(
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.structure)
await self._async_handle_chat_log(
chat_log, task.structure, default_max_tokens=RECOMMENDED_A_TASK_MAX_TOKENS
)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(

View File

@@ -32,6 +32,8 @@ CONF_TOP_K = "top_k"
RECOMMENDED_TOP_K = 64
CONF_MAX_TOKENS = "max_tokens"
RECOMMENDED_MAX_TOKENS = 3000
# Cost estimate: 5000 input tokens and 19400 output tokens come to 0.05 USD
RECOMMENDED_A_TASK_MAX_TOKENS = 19400
CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"

View File

@@ -472,6 +472,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
self,
chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
default_max_tokens: int | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
@@ -618,7 +619,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
if not chat_log.unresponded_tool_results:
break
def create_generate_content_config(self) -> GenerateContentConfig:
def create_generate_content_config(
self, default_max_tokens: int | None = None
) -> GenerateContentConfig:
"""Create the GenerateContentConfig for the LLM."""
options = self.subentry.data
model = options.get(CONF_CHAT_MODEL, self.default_model)
@@ -632,7 +635,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
top_k=options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
max_output_tokens=options.get(
CONF_MAX_TOKENS, default_max_tokens or RECOMMENDED_MAX_TOKENS
),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,