mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 10:59:40 +00:00
Increase AI Task default tokens for Google Gemini
This commit is contained in:
@@ -15,7 +15,13 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .const import CONF_CHAT_MODEL, CONF_RECOMMENDED, LOGGER, RECOMMENDED_IMAGE_MODEL
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_RECOMMENDED,
|
||||
LOGGER,
|
||||
RECOMMENDED_A_TASK_MAX_TOKENS,
|
||||
RECOMMENDED_IMAGE_MODEL,
|
||||
)
|
||||
from .entity import (
|
||||
ERROR_GETTING_RESPONSE,
|
||||
GoogleGenerativeAILLMBaseEntity,
|
||||
@@ -73,7 +79,9 @@ class GoogleGenerativeAITaskEntity(
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log, task.structure)
|
||||
await self._async_handle_chat_log(
|
||||
chat_log, task.structure, default_max_tokens=RECOMMENDED_A_TASK_MAX_TOKENS
|
||||
)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
|
||||
@@ -32,6 +32,8 @@ CONF_TOP_K = "top_k"
|
||||
RECOMMENDED_TOP_K = 64
|
||||
CONF_MAX_TOKENS = "max_tokens"
|
||||
RECOMMENDED_MAX_TOKENS = 3000
|
||||
# Input 5000 tokens, output 19400 tokens = 0.05 USD
|
||||
RECOMMENDED_A_TASK_MAX_TOKENS = 19400
|
||||
CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
|
||||
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||
|
||||
@@ -472,6 +472,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure: vol.Schema | None = None,
|
||||
default_max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
@@ -618,7 +619,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
||||
def create_generate_content_config(self) -> GenerateContentConfig:
|
||||
def create_generate_content_config(
|
||||
self, default_max_tokens: int | None = None
|
||||
) -> GenerateContentConfig:
|
||||
"""Create the GenerateContentConfig for the LLM."""
|
||||
options = self.subentry.data
|
||||
model = options.get(CONF_CHAT_MODEL, self.default_model)
|
||||
@@ -632,7 +635,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
top_k=options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
max_output_tokens=options.get(
|
||||
CONF_MAX_TOKENS, default_max_tokens or RECOMMENDED_MAX_TOKENS
|
||||
),
|
||||
safety_settings=[
|
||||
SafetySetting(
|
||||
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
|
||||
Reference in New Issue
Block a user