mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Add AI Task platform to Google Gen AI (#146766)
This commit is contained in:
parent
a3b03caead
commit
04cc451c76
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
@ -37,11 +38,13 @@ from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
FILE_POLLING_INTERVAL_SECONDS,
|
||||
LOGGER,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
TIMEOUT_MILLIS,
|
||||
@ -53,6 +56,7 @@ CONF_FILENAMES = "filenames"
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (
|
||||
Platform.AI_TASK,
|
||||
Platform.CONVERSATION,
|
||||
Platform.TTS,
|
||||
)
|
||||
@ -187,11 +191,9 @@ async def async_setup_entry(
|
||||
"""Set up Google Generative AI Conversation from a config entry."""
|
||||
|
||||
try:
|
||||
|
||||
def _init_client() -> Client:
|
||||
return Client(api_key=entry.data[CONF_API_KEY])
|
||||
|
||||
client = await hass.async_add_executor_job(_init_client)
|
||||
client = await hass.async_add_executor_job(
|
||||
partial(Client, api_key=entry.data[CONF_API_KEY])
|
||||
)
|
||||
await client.aio.models.get(
|
||||
model=RECOMMENDED_CHAT_MODEL,
|
||||
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
|
||||
@ -350,6 +352,19 @@ async def async_migrate_entry(
|
||||
|
||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||
|
||||
if entry.version == 2 and entry.minor_version == 2:
|
||||
# Add AI Task subentry with default options
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
|
||||
subentry_type="ai_task_data",
|
||||
title=DEFAULT_AI_TASK_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
hass.config_entries.async_update_entry(entry, minor_version=3)
|
||||
|
||||
LOGGER.debug(
|
||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||
)
|
||||
|
@ -0,0 +1,57 @@
|
||||
"""AI Task integration for Google Generative AI Conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import LOGGER
|
||||
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AI Task entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "ai_task_data":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[GoogleGenerativeAITaskEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
|
||||
class GoogleGenerativeAITaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
GoogleGenerativeAILLMBaseEntity,
|
||||
):
|
||||
"""Google Generative AI AI Task entity."""
|
||||
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: ai_task.GenDataTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||
chat_log.content[-1],
|
||||
)
|
||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
||||
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=chat_log.content[-1].content or "",
|
||||
)
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
@ -46,10 +47,12 @@ from .const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
@ -72,12 +75,14 @@ STEP_API_DATA_SCHEMA = vol.Schema(
|
||||
)
|
||||
|
||||
|
||||
async def validate_input(data: dict[str, Any]) -> None:
|
||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||
"""Validate the user input allows us to connect.
|
||||
|
||||
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
|
||||
"""
|
||||
client = genai.Client(api_key=data[CONF_API_KEY])
|
||||
client = await hass.async_add_executor_job(
|
||||
partial(genai.Client, api_key=data[CONF_API_KEY])
|
||||
)
|
||||
await client.aio.models.list(
|
||||
config={
|
||||
"http_options": {
|
||||
@ -92,7 +97,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Google Generative AI Conversation."""
|
||||
|
||||
VERSION = 2
|
||||
MINOR_VERSION = 2
|
||||
MINOR_VERSION = 3
|
||||
|
||||
async def async_step_api(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@ -102,7 +107,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
if user_input is not None:
|
||||
self._async_abort_entries_match(user_input)
|
||||
try:
|
||||
await validate_input(user_input)
|
||||
await validate_input(self.hass, user_input)
|
||||
except (APIError, Timeout) as err:
|
||||
if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
|
||||
errors["base"] = "invalid_auth"
|
||||
@ -133,6 +138,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "ai_task_data",
|
||||
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
return self.async_show_form(
|
||||
@ -181,6 +192,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
return {
|
||||
"conversation": LLMSubentryFlowHandler,
|
||||
"tts": LLMSubentryFlowHandler,
|
||||
"ai_task_data": LLMSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
@ -214,6 +226,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
options: dict[str, Any]
|
||||
if self._subentry_type == "tts":
|
||||
options = RECOMMENDED_TTS_OPTIONS.copy()
|
||||
elif self._subentry_type == "ai_task_data":
|
||||
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||
else:
|
||||
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
else:
|
||||
@ -288,6 +302,8 @@ async def google_generative_ai_config_option_schema(
|
||||
default_name = options[CONF_NAME]
|
||||
elif subentry_type == "tts":
|
||||
default_name = DEFAULT_TTS_NAME
|
||||
elif subentry_type == "ai_task_data":
|
||||
default_name = DEFAULT_AI_TASK_NAME
|
||||
else:
|
||||
default_name = DEFAULT_CONVERSATION_NAME
|
||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||
@ -315,6 +331,7 @@ async def google_generative_ai_config_option_schema(
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
schema.update(
|
||||
{
|
||||
vol.Required(
|
||||
@ -443,4 +460,5 @@ async def google_generative_ai_config_option_schema(
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
|
||||
return schema
|
||||
|
@ -12,6 +12,7 @@ CONF_PROMPT = "prompt"
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
||||
DEFAULT_TTS_NAME = "Google AI TTS"
|
||||
DEFAULT_AI_TASK_NAME = "Google AI Task"
|
||||
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
@ -35,6 +36,7 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||
|
||||
TIMEOUT_MILLIS = 10000
|
||||
FILE_POLLING_INTERVAL_SECONDS = 0.05
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
@ -44,3 +46,7 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
RECOMMENDED_TTS_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
||||
RECOMMENDED_AI_TASK_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
@ -88,6 +88,34 @@
|
||||
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
}
|
||||
},
|
||||
"ai_task_data": {
|
||||
"initiate_flow": {
|
||||
"user": "Add Generate data with AI service",
|
||||
"reconfigure": "Reconfigure Generate data with AI service"
|
||||
},
|
||||
"entry_type": "Generate data with AI service",
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
|
||||
"chat_model": "[%key:common::generic::model%]",
|
||||
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
|
||||
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
|
||||
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
|
||||
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
|
||||
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
|
||||
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
|
||||
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
|
||||
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
)
|
||||
@ -29,6 +30,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"api_key": "bla",
|
||||
},
|
||||
version=2,
|
||||
minor_version=3,
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {},
|
||||
@ -44,6 +46,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"subentry_id": "ulid-tts",
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": {},
|
||||
"subentry_type": "ai_task_data",
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"subentry_id": "ulid-ai-task",
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
entry.runtime_data = Mock()
|
||||
|
@ -7,6 +7,14 @@
|
||||
'options': dict({
|
||||
}),
|
||||
'subentries': dict({
|
||||
'ulid-ai-task': dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'subentry_id': 'ulid-ai-task',
|
||||
'subentry_type': 'ai_task_data',
|
||||
'title': 'Google AI Task',
|
||||
'unique_id': None,
|
||||
}),
|
||||
'ulid-conversation': dict({
|
||||
'data': dict({
|
||||
'chat_model': 'models/gemini-2.5-flash',
|
||||
|
@ -32,6 +32,37 @@
|
||||
'sw_version': None,
|
||||
'via_device_id': None,
|
||||
}),
|
||||
DeviceRegistryEntrySnapshot({
|
||||
'area_id': None,
|
||||
'config_entries': <ANY>,
|
||||
'config_entries_subentries': <ANY>,
|
||||
'configuration_url': None,
|
||||
'connections': set({
|
||||
}),
|
||||
'disabled_by': None,
|
||||
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
|
||||
'hw_version': None,
|
||||
'id': <ANY>,
|
||||
'identifiers': set({
|
||||
tuple(
|
||||
'google_generative_ai_conversation',
|
||||
'ulid-ai-task',
|
||||
),
|
||||
}),
|
||||
'is_new': False,
|
||||
'labels': set({
|
||||
}),
|
||||
'manufacturer': 'Google',
|
||||
'model': 'gemini-2.5-flash',
|
||||
'model_id': None,
|
||||
'name': 'Google AI Task',
|
||||
'name_by_user': None,
|
||||
'primary_config_entry': <ANY>,
|
||||
'serial_number': None,
|
||||
'suggested_area': None,
|
||||
'sw_version': None,
|
||||
'via_device_id': None,
|
||||
}),
|
||||
DeviceRegistryEntrySnapshot({
|
||||
'area_id': None,
|
||||
'config_entries': <ANY>,
|
||||
|
@ -0,0 +1,62 @@
|
||||
"""Test AI Task platform of Google Generative AI Conversation integration."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from google.genai.types import GenerateContentResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.conversation import (
|
||||
MockChatLog,
|
||||
mock_chat_log, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_run_task(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
entity_id = "ai_task.google_ai_task"
|
||||
|
||||
# Ensure it's linked to the subentry
|
||||
entity_entry = entity_registry.async_get(entity_id)
|
||||
ai_task_entry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "ai_task_data"
|
||||
)
|
||||
)
|
||||
assert entity_entry.config_entry_id == mock_config_entry.entry_id
|
||||
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
|
||||
|
||||
mock_send_message_stream.return_value = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [{"text": "Hi there!"}],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.data == "Hi there!"
|
@ -19,9 +19,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
@ -121,6 +123,12 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "ai_task_data",
|
||||
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
]
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@ -222,7 +230,7 @@ async def test_creating_tts_subentry(
|
||||
assert result2["title"] == "Mock TTS"
|
||||
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
|
||||
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
assert len(mock_config_entry.subentries) == 4
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
@ -232,13 +240,59 @@ async def test_creating_tts_subentry(
|
||||
assert new_subentry.title == "Mock TTS"
|
||||
|
||||
|
||||
async def test_creating_ai_task_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry."""
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM, result
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock AI Task"
|
||||
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
|
||||
|
||||
assert len(mock_config_entry.subentries) == 4
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == "ai_task_data"
|
||||
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert new_subentry.title == "Mock AI Task"
|
||||
|
||||
|
||||
async def test_creating_conversation_subentry_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a conversation subentry."""
|
||||
"""Test that subentry fails to init if entry not loaded."""
|
||||
await hass.config_entries.async_unload(mock_config_entry.entry_id)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
|
@ -8,9 +8,13 @@ from requests.exceptions import Timeout
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntryState, ConfigSubentryData
|
||||
@ -397,7 +401,7 @@ async def test_load_entry_with_unloaded_entries(
|
||||
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2(
|
||||
async def test_migration_from_v1(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
@ -473,10 +477,10 @@ async def test_migration_from_v1_to_v2(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.minor_version == 3
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 3
|
||||
assert len(entry.subentries) == 4
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -495,6 +499,14 @@ async def test_migration_from_v1_to_v2(
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
ai_task_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
]
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -542,7 +554,7 @@ async def test_migration_from_v1_to_v2(
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2_with_multiple_keys(
|
||||
async def test_migration_from_v1_with_multiple_keys(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
@ -619,10 +631,10 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
|
||||
|
||||
for entry in entries:
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.minor_version == 3
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 2
|
||||
assert len(entry.subentries) == 3
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
@ -631,6 +643,10 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
|
||||
assert subentry.subentry_type == "tts"
|
||||
assert subentry.data == RECOMMENDED_TTS_OPTIONS
|
||||
assert subentry.title == DEFAULT_TTS_NAME
|
||||
subentry = list(entry.subentries.values())[2]
|
||||
assert subentry.subentry_type == "ai_task_data"
|
||||
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert subentry.title == DEFAULT_AI_TASK_NAME
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||
@ -642,7 +658,7 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
async def test_migration_from_v1_with_same_keys(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
@ -718,10 +734,10 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.minor_version == 3
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 3
|
||||
assert len(entry.subentries) == 4
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -740,6 +756,14 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
ai_task_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
]
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -829,7 +853,7 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_migration_from_v2_1_to_v2_2(
|
||||
async def test_migration_from_v2_1(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
@ -837,12 +861,13 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
extra_subentries: list[ConfigSubentryData],
|
||||
expected_device_subentries: dict[str, set[str | None]],
|
||||
) -> None:
|
||||
"""Test migration from version 2.1 to version 2.2.
|
||||
"""Test migration from version 2.1.
|
||||
|
||||
This tests we clean up the broken migration in Home Assistant Core
|
||||
2025.7.0b0-2025.7.0b1:
|
||||
2025.7.0b0-2025.7.0b1 and add AI Task subentry:
|
||||
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
|
||||
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
|
||||
- Add AI Task subentry (Added in version 2.3)
|
||||
"""
|
||||
# Create a v2.1 config entry with 2 subentries, devices and entities
|
||||
options = {
|
||||
@ -930,10 +955,10 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.minor_version == 3
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 3
|
||||
assert len(entry.subentries) == 4
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -952,6 +977,14 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
ai_task_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
]
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -1011,3 +1044,80 @@ async def test_devices(
|
||||
device_registry, mock_config_entry.entry_id
|
||||
)
|
||||
assert devices == snapshot
|
||||
|
||||
|
||||
async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
|
||||
"""Test migration from version 2.2."""
|
||||
# Create a v2.2 config entry with conversation and TTS subentries
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_API_KEY: "test-api-key"},
|
||||
version=2,
|
||||
minor_version=2,
|
||||
subentries_data=[
|
||||
{
|
||||
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
"subentry_type": "conversation",
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": RECOMMENDED_TTS_OPTIONS,
|
||||
"subentry_type": "tts",
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
# Verify initial state
|
||||
assert mock_config_entry.version == 2
|
||||
assert mock_config_entry.minor_version == 2
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
# Run setup to trigger migration
|
||||
with patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
result = await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
assert result is True
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Verify migration completed
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
|
||||
# Check version and subversion were updated
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 3
|
||||
|
||||
# Check we now have conversation, tts and ai_task_data subentries
|
||||
assert len(entry.subentries) == 3
|
||||
|
||||
subentries = {
|
||||
subentry.subentry_type: subentry for subentry in entry.subentries.values()
|
||||
}
|
||||
assert "conversation" in subentries
|
||||
assert "tts" in subentries
|
||||
assert "ai_task_data" in subentries
|
||||
|
||||
# Find and verify the ai_task_data subentry
|
||||
ai_task_subentry = subentries["ai_task_data"]
|
||||
assert ai_task_subentry is not None
|
||||
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
|
||||
# Verify conversation subentry is still there and unchanged
|
||||
conversation_subentry = subentries["conversation"]
|
||||
assert conversation_subentry is not None
|
||||
assert conversation_subentry.title == DEFAULT_CONVERSATION_NAME
|
||||
assert conversation_subentry.data == RECOMMENDED_CONVERSATION_OPTIONS
|
||||
|
||||
# Verify TTS subentry is still there and unchanged
|
||||
tts_subentry = subentries["tts"]
|
||||
assert tts_subentry is not None
|
||||
assert tts_subentry.title == DEFAULT_TTS_NAME
|
||||
assert tts_subentry.data == RECOMMENDED_TTS_OPTIONS
|
||||
|
Loading…
x
Reference in New Issue
Block a user