Add AI Task platform to Google Gen AI (#146766)

This commit is contained in:
Paulus Schoutsen 2025-07-04 08:36:34 +02:00 committed by GitHub
parent a3b03caead
commit 04cc451c76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 423 additions and 25 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from functools import partial
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from types import MappingProxyType from types import MappingProxyType
@ -37,11 +38,13 @@ from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
CONF_PROMPT, CONF_PROMPT,
DEFAULT_AI_TASK_NAME,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
FILE_POLLING_INTERVAL_SECONDS, FILE_POLLING_INTERVAL_SECONDS,
LOGGER, LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_TTS_OPTIONS, RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS, TIMEOUT_MILLIS,
@ -53,6 +56,7 @@ CONF_FILENAMES = "filenames"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = ( PLATFORMS = (
Platform.AI_TASK,
Platform.CONVERSATION, Platform.CONVERSATION,
Platform.TTS, Platform.TTS,
) )
@ -187,11 +191,9 @@ async def async_setup_entry(
"""Set up Google Generative AI Conversation from a config entry.""" """Set up Google Generative AI Conversation from a config entry."""
try: try:
client = await hass.async_add_executor_job(
def _init_client() -> Client: partial(Client, api_key=entry.data[CONF_API_KEY])
return Client(api_key=entry.data[CONF_API_KEY]) )
client = await hass.async_add_executor_job(_init_client)
await client.aio.models.get( await client.aio.models.get(
model=RECOMMENDED_CHAT_MODEL, model=RECOMMENDED_CHAT_MODEL,
config={"http_options": {"timeout": TIMEOUT_MILLIS}}, config={"http_options": {"timeout": TIMEOUT_MILLIS}},
@ -350,6 +352,19 @@ async def async_migrate_entry(
hass.config_entries.async_update_entry(entry, minor_version=2) hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
# Add AI Task subentry with default options
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=3)
LOGGER.debug( LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version "Migration to version %s:%s successful", entry.version, entry.minor_version
) )

View File

@ -0,0 +1,57 @@
"""AI Task integration for Google Generative AI Conversation."""
from __future__ import annotations
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue
async_add_entities(
[GoogleGenerativeAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class GoogleGenerativeAITaskEntity(
ai_task.AITaskEntity,
GoogleGenerativeAILLMBaseEntity,
):
"""Google Generative AI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=chat_log.content[-1].content or "",
)

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -46,10 +47,12 @@ from .const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
@ -72,12 +75,14 @@ STEP_API_DATA_SCHEMA = vol.Schema(
) )
async def validate_input(data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
""" """
client = genai.Client(api_key=data[CONF_API_KEY]) client = await hass.async_add_executor_job(
partial(genai.Client, api_key=data[CONF_API_KEY])
)
await client.aio.models.list( await client.aio.models.list(
config={ config={
"http_options": { "http_options": {
@ -92,7 +97,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Google Generative AI Conversation.""" """Handle a config flow for Google Generative AI Conversation."""
VERSION = 2 VERSION = 2
MINOR_VERSION = 2 MINOR_VERSION = 3
async def async_step_api( async def async_step_api(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -102,7 +107,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
self._async_abort_entries_match(user_input) self._async_abort_entries_match(user_input)
try: try:
await validate_input(user_input) await validate_input(self.hass, user_input)
except (APIError, Timeout) as err: except (APIError, Timeout) as err:
if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err): if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
@ -133,6 +138,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"title": DEFAULT_TTS_NAME, "title": DEFAULT_TTS_NAME,
"unique_id": None, "unique_id": None,
}, },
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
], ],
) )
return self.async_show_form( return self.async_show_form(
@ -181,6 +192,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
return { return {
"conversation": LLMSubentryFlowHandler, "conversation": LLMSubentryFlowHandler,
"tts": LLMSubentryFlowHandler, "tts": LLMSubentryFlowHandler,
"ai_task_data": LLMSubentryFlowHandler,
} }
@ -214,6 +226,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
options: dict[str, Any] options: dict[str, Any]
if self._subentry_type == "tts": if self._subentry_type == "tts":
options = RECOMMENDED_TTS_OPTIONS.copy() options = RECOMMENDED_TTS_OPTIONS.copy()
elif self._subentry_type == "ai_task_data":
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else: else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy() options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
else: else:
@ -288,6 +302,8 @@ async def google_generative_ai_config_option_schema(
default_name = options[CONF_NAME] default_name = options[CONF_NAME]
elif subentry_type == "tts": elif subentry_type == "tts":
default_name = DEFAULT_TTS_NAME default_name = DEFAULT_TTS_NAME
elif subentry_type == "ai_task_data":
default_name = DEFAULT_AI_TASK_NAME
else: else:
default_name = DEFAULT_CONVERSATION_NAME default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = { schema: dict[vol.Required | vol.Optional, Any] = {
@ -315,6 +331,7 @@ async def google_generative_ai_config_option_schema(
), ),
} }
) )
schema.update( schema.update(
{ {
vol.Required( vol.Required(
@ -443,4 +460,5 @@ async def google_generative_ai_config_option_schema(
): bool, ): bool,
} }
) )
return schema return schema

View File

@ -12,6 +12,7 @@ CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation" DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_TTS_NAME = "Google AI TTS" DEFAULT_TTS_NAME = "Google AI TTS"
DEFAULT_AI_TASK_NAME = "Google AI Task"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
@ -35,6 +36,7 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
TIMEOUT_MILLIS = 10000 TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05 FILE_POLLING_INTERVAL_SECONDS = 0.05
RECOMMENDED_CONVERSATION_OPTIONS = { RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
@ -44,3 +46,7 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
RECOMMENDED_TTS_OPTIONS = { RECOMMENDED_TTS_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
} }
RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}

View File

@ -88,6 +88,34 @@
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]", "entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]" "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
} }
},
"ai_task_data": {
"initiate_flow": {
"user": "Add Generate data with AI service",
"reconfigure": "Reconfigure Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
}
}
},
"abort": {
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
} }
}, },
"services": { "services": {

View File

@ -7,6 +7,7 @@ import pytest
from homeassistant.components.google_generative_ai_conversation.const import ( from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
) )
@ -29,6 +30,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"api_key": "bla", "api_key": "bla",
}, },
version=2, version=2,
minor_version=3,
subentries_data=[ subentries_data=[
{ {
"data": {}, "data": {},
@ -44,6 +46,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"subentry_id": "ulid-tts", "subentry_id": "ulid-tts",
"unique_id": None, "unique_id": None,
}, },
{
"data": {},
"subentry_type": "ai_task_data",
"title": DEFAULT_AI_TASK_NAME,
"subentry_id": "ulid-ai-task",
"unique_id": None,
},
], ],
) )
entry.runtime_data = Mock() entry.runtime_data = Mock()

View File

@ -7,6 +7,14 @@
'options': dict({ 'options': dict({
}), }),
'subentries': dict({ 'subentries': dict({
'ulid-ai-task': dict({
'data': dict({
}),
'subentry_id': 'ulid-ai-task',
'subentry_type': 'ai_task_data',
'title': 'Google AI Task',
'unique_id': None,
}),
'ulid-conversation': dict({ 'ulid-conversation': dict({
'data': dict({ 'data': dict({
'chat_model': 'models/gemini-2.5-flash', 'chat_model': 'models/gemini-2.5-flash',

View File

@ -32,6 +32,37 @@
'sw_version': None, 'sw_version': None,
'via_device_id': None, 'via_device_id': None,
}), }),
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'config_entries_subentries': <ANY>,
'configuration_url': None,
'connections': set({
}),
'disabled_by': None,
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
'hw_version': None,
'id': <ANY>,
'identifiers': set({
tuple(
'google_generative_ai_conversation',
'ulid-ai-task',
),
}),
'is_new': False,
'labels': set({
}),
'manufacturer': 'Google',
'model': 'gemini-2.5-flash',
'model_id': None,
'name': 'Google AI Task',
'name_by_user': None,
'primary_config_entry': <ANY>,
'serial_number': None,
'suggested_area': None,
'sw_version': None,
'via_device_id': None,
}),
DeviceRegistryEntrySnapshot({ DeviceRegistryEntrySnapshot({
'area_id': None, 'area_id': None,
'config_entries': <ANY>, 'config_entries': <ANY>,

View File

@ -0,0 +1,62 @@
"""Test AI Task platform of Google Generative AI Conversation integration."""
from unittest.mock import AsyncMock
from google.genai.types import GenerateContentResponse
import pytest
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
@pytest.mark.usefixtures("mock_init_component")
async def test_run_task(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test empty response."""
entity_id = "ai_task.google_ai_task"
# Ensure it's linked to the subentry
entity_entry = entity_registry.async_get(entity_id)
ai_task_entry = next(
iter(
entry
for entry in mock_config_entry.subentries.values()
if entry.subentry_type == "ai_task_data"
)
)
assert entity_entry.config_entry_id == mock_config_entry.entry_id
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "Hi there!"}],
"role": "model",
},
}
],
),
],
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
)
assert result.data == "Hi there!"

View File

@ -19,9 +19,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
@ -121,6 +123,12 @@ async def test_form(hass: HomeAssistant) -> None:
"title": DEFAULT_TTS_NAME, "title": DEFAULT_TTS_NAME,
"unique_id": None, "unique_id": None,
}, },
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
] ]
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -222,7 +230,7 @@ async def test_creating_tts_subentry(
assert result2["title"] == "Mock TTS" assert result2["title"] == "Mock TTS"
assert result2["data"] == RECOMMENDED_TTS_OPTIONS assert result2["data"] == RECOMMENDED_TTS_OPTIONS
assert len(mock_config_entry.subentries) == 3 assert len(mock_config_entry.subentries) == 4
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0] new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id] new_subentry = mock_config_entry.subentries[new_subentry_id]
@ -232,13 +240,59 @@ async def test_creating_tts_subentry(
assert new_subentry.title == "Mock TTS" assert new_subentry.title == "Mock TTS"
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry."""
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock AI Task"
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
assert len(mock_config_entry.subentries) == 4
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task_data"
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert new_subentry.title == "Mock AI Task"
async def test_creating_conversation_subentry_not_loaded( async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant, hass: HomeAssistant,
mock_init_component: None, mock_init_component: None,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test creating a conversation subentry.""" """Test that subentry fails to init if entry not loaded."""
await hass.config_entries.async_unload(mock_config_entry.entry_id) await hass.config_entries.async_unload(mock_config_entry.entry_id)
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),

View File

@ -8,9 +8,13 @@ from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation.const import ( from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_TTS_OPTIONS, RECOMMENDED_TTS_OPTIONS,
) )
from homeassistant.config_entries import ConfigEntryState, ConfigSubentryData from homeassistant.config_entries import ConfigEntryState, ConfigSubentryData
@ -397,7 +401,7 @@ async def test_load_entry_with_unloaded_entries(
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
async def test_migration_from_v1_to_v2( async def test_migration_from_v1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -473,10 +477,10 @@ async def test_migration_from_v1_to_v2(
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 3 assert len(entry.subentries) == 4
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -495,6 +499,14 @@ async def test_migration_from_v1_to_v2(
assert len(tts_subentries) == 1 assert len(tts_subentries) == 1
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
assert tts_subentries[0].title == DEFAULT_TTS_NAME assert tts_subentries[0].title == DEFAULT_TTS_NAME
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -542,7 +554,7 @@ async def test_migration_from_v1_to_v2(
} }
async def test_migration_from_v1_to_v2_with_multiple_keys( async def test_migration_from_v1_with_multiple_keys(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -619,10 +631,10 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
for entry in entries: for entry in entries:
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 2 assert len(entry.subentries) == 3
subentry = list(entry.subentries.values())[0] subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == options assert subentry.data == options
@ -631,6 +643,10 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
assert subentry.subentry_type == "tts" assert subentry.subentry_type == "tts"
assert subentry.data == RECOMMENDED_TTS_OPTIONS assert subentry.data == RECOMMENDED_TTS_OPTIONS
assert subentry.title == DEFAULT_TTS_NAME assert subentry.title == DEFAULT_TTS_NAME
subentry = list(entry.subentries.values())[2]
assert subentry.subentry_type == "ai_task_data"
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert subentry.title == DEFAULT_AI_TASK_NAME
dev = device_registry.async_get_device( dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
@ -642,7 +658,7 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
} }
async def test_migration_from_v1_to_v2_with_same_keys( async def test_migration_from_v1_with_same_keys(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -718,10 +734,10 @@ async def test_migration_from_v1_to_v2_with_same_keys(
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 3 assert len(entry.subentries) == 4
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -740,6 +756,14 @@ async def test_migration_from_v1_to_v2_with_same_keys(
assert len(tts_subentries) == 1 assert len(tts_subentries) == 1
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
assert tts_subentries[0].title == DEFAULT_TTS_NAME assert tts_subentries[0].title == DEFAULT_TTS_NAME
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -829,7 +853,7 @@ async def test_migration_from_v1_to_v2_with_same_keys(
), ),
], ],
) )
async def test_migration_from_v2_1_to_v2_2( async def test_migration_from_v2_1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -837,12 +861,13 @@ async def test_migration_from_v2_1_to_v2_2(
extra_subentries: list[ConfigSubentryData], extra_subentries: list[ConfigSubentryData],
expected_device_subentries: dict[str, set[str | None]], expected_device_subentries: dict[str, set[str | None]],
) -> None: ) -> None:
"""Test migration from version 2.1 to version 2.2. """Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1: 2025.7.0b0-2025.7.0b1 and add AI Task subentry:
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2) - Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1) - Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
- Add AI Task subentry (Added in version 2.3)
""" """
# Create a v2.1 config entry with 2 subentries, devices and entities # Create a v2.1 config entry with 2 subentries, devices and entities
options = { options = {
@ -930,10 +955,10 @@ async def test_migration_from_v2_1_to_v2_2(
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 3 assert len(entry.subentries) == 4
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -952,6 +977,14 @@ async def test_migration_from_v2_1_to_v2_2(
assert len(tts_subentries) == 1 assert len(tts_subentries) == 1
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
assert tts_subentries[0].title == DEFAULT_TTS_NAME assert tts_subentries[0].title == DEFAULT_TTS_NAME
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -1011,3 +1044,80 @@ async def test_devices(
device_registry, mock_config_entry.entry_id device_registry, mock_config_entry.entry_id
) )
assert devices == snapshot assert devices == snapshot
async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
"""Test migration from version 2.2."""
# Create a v2.2 config entry with conversation and TTS subentries
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_API_KEY: "test-api-key"},
version=2,
minor_version=2,
subentries_data=[
{
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
},
{
"data": RECOMMENDED_TTS_OPTIONS,
"subentry_type": "tts",
"title": DEFAULT_TTS_NAME,
"unique_id": None,
},
],
)
mock_config_entry.add_to_hass(hass)
# Verify initial state
assert mock_config_entry.version == 2
assert mock_config_entry.minor_version == 2
assert len(mock_config_entry.subentries) == 2
# Run setup to trigger migration
with patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
return_value=True,
):
result = await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert result is True
await hass.async_block_till_done()
# Verify migration completed
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
# Check version and subversion were updated
assert entry.version == 2
assert entry.minor_version == 3
# Check we now have conversation, tts and ai_task_data subentries
assert len(entry.subentries) == 3
subentries = {
subentry.subentry_type: subentry for subentry in entry.subentries.values()
}
assert "conversation" in subentries
assert "tts" in subentries
assert "ai_task_data" in subentries
# Find and verify the ai_task_data subentry
ai_task_subentry = subentries["ai_task_data"]
assert ai_task_subentry is not None
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
# Verify conversation subentry is still there and unchanged
conversation_subentry = subentries["conversation"]
assert conversation_subentry is not None
assert conversation_subentry.title == DEFAULT_CONVERSATION_NAME
assert conversation_subentry.data == RECOMMENDED_CONVERSATION_OPTIONS
# Verify TTS subentry is still there and unchanged
tts_subentry = subentries["tts"]
assert tts_subentry is not None
assert tts_subentry.title == DEFAULT_TTS_NAME
assert tts_subentry.data == RECOMMENDED_TTS_OPTIONS