From 56f4039ac26fd39945dfe948c4c314237463f688 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 23 Jun 2025 20:59:32 -0400 Subject: [PATCH] Migrate Google Gen AI to use subentries (#147281) * Migrate Google Gen AI to use subentries * Add reconfig successful msg * Address comments * Do not allow addin subentry when not loaded * Let HA do the migration * Use config_entries.async_setup * Remove fallback name on base entity * Fix * Fix * Fix device name assignment in entity and tts modules * Fix tests --------- Co-authored-by: Joostlek --- .../__init__.py | 80 ++++++- .../config_flow.py | 148 ++++++++---- .../const.py | 2 + .../conversation.py | 20 +- .../entity.py | 17 +- .../strings.json | 64 ++++-- .../google_generative_ai_conversation/tts.py | 1 - .../conftest.py | 23 +- .../test_config_flow.py | 122 ++++++++-- .../test_conversation.py | 26 ++- .../test_init.py | 214 ++++++++++++++++++ .../test_tts.py | 4 +- 12 files changed, 599 insertions(+), 122 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index 7e9ca550275..4830e204654 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -12,7 +12,7 @@ from google.genai.types import File, FileState from requests.exceptions import Timeout import voluptuous as vol -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_API_KEY, Platform from homeassistant.core import ( HomeAssistant, @@ -26,7 +26,11 @@ from homeassistant.exceptions import ( ConfigEntryNotReady, HomeAssistantError, ) -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType @@ -56,6 +60,8 @@ type GoogleGenerativeAIConfigEntry = ConfigEntry[Client] async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Google Generative AI Conversation.""" + await async_migrate_integration(hass) + async def generate_content(call: ServiceCall) -> ServiceResponse: """Generate content from text and optionally images.""" @@ -209,3 +215,73 @@ async def async_unload_entry( return False return True + + +async def async_migrate_integration(hass: HomeAssistant) -> None: + """Migrate integration entry structure.""" + + entries = hass.config_entries.async_entries(DOMAIN) + if not any(entry.version == 1 for entry in entries): + return + + api_keys_entries: dict[str, ConfigEntry] = {} + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + + for entry in entries: + use_existing = False + subentry = ConfigSubentry( + data=entry.options, + subentry_type="conversation", + title=entry.title, + unique_id=None, + ) + if entry.data[CONF_API_KEY] not in api_keys_entries: + use_existing = True + api_keys_entries[entry.data[CONF_API_KEY]] = entry + + parent_entry = api_keys_entries[entry.data[CONF_API_KEY]] + + hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( + "conversation", + DOMAIN, + entry.entry_id, + ) + if conversation_entity is not None: + entity_registry.async_update_entity( + conversation_entity, + config_entry_id=parent_entry.entry_id, + config_subentry_id=subentry.subentry_id, + new_unique_id=subentry.subentry_id, + ) + + device = device_registry.async_get_device( + identifiers={(DOMAIN, entry.entry_id)} + ) + if device is not None: + device_registry.async_update_device( + device.id, + new_identifiers={(DOMAIN, subentry.subentry_id)}, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + if parent_entry.entry_id != entry.entry_id: + device_registry.async_update_device( + device.id, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + device_registry.async_update_device( + device.id, + remove_config_entry_id=entry.entry_id, + ) + + if not use_existing: + await hass.config_entries.async_remove(entry.entry_id) + else: + hass.config_entries.async_update_entry( + entry, + options={}, + version=2, + ) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index ae0f09b1037..4b7c7a0dd47 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -4,8 +4,7 @@ from __future__ import annotations from collections.abc import Mapping import logging -from types import MappingProxyType -from typing import Any +from typing import Any, cast from google import genai from google.genai.errors import APIError, ClientError @@ -15,12 +14,14 @@ import voluptuous as vol from homeassistant.config_entries import ( SOURCE_REAUTH, ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, - OptionsFlow, + ConfigSubentryFlow, + SubentryFlowResult, ) from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import llm from homeassistant.helpers.selector import ( NumberSelector, @@ -45,6 +46,7 @@ from .const import ( CONF_TOP_K, CONF_TOP_P, CONF_USE_GOOGLE_SEARCH_TOOL, + DEFAULT_CONVERSATION_NAME, DOMAIN, RECOMMENDED_CHAT_MODEL, RECOMMENDED_HARM_BLOCK_THRESHOLD, @@ -66,7 +68,7 @@ STEP_API_DATA_SCHEMA = vol.Schema( RECOMMENDED_OPTIONS = { CONF_RECOMMENDED: True, - CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, } @@ -90,7 +92,7 @@ async def validate_input(data: dict[str, Any]) -> None: class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Google Generative AI Conversation.""" - VERSION = 1 + VERSION = 2 async def async_step_api( self, user_input: dict[str, Any] | None = None @@ -98,6 +100,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): """Handle the initial step.""" errors: dict[str, str] = {} if user_input is not None: + self._async_abort_entries_match(user_input) try: await validate_input(user_input) except (APIError, Timeout) as err: @@ -117,7 +120,14 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_create_entry( title="Google Generative AI", data=user_input, - options=RECOMMENDED_OPTIONS, + subentries=[ + { + "subentry_type": "conversation", + "data": RECOMMENDED_OPTIONS, + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ], ) return self.async_show_form( step_id="api", @@ -156,41 +166,72 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): }, ) - @staticmethod - def async_get_options_flow( - config_entry: ConfigEntry, - ) -> OptionsFlow: - """Create the options flow.""" - return GoogleGenerativeAIOptionsFlow(config_entry) + @classmethod + @callback + def async_get_supported_subentry_types( + cls, config_entry: ConfigEntry + ) -> dict[str, type[ConfigSubentryFlow]]: + """Return subentries supported by this integration.""" + return {"conversation": ConversationSubentryFlowHandler} -class GoogleGenerativeAIOptionsFlow(OptionsFlow): - """Google Generative AI config flow options handler.""" +class ConversationSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing conversation subentries.""" - def __init__(self, config_entry: ConfigEntry) -> None: - """Initialize options flow.""" - self.last_rendered_recommended = config_entry.options.get( - CONF_RECOMMENDED, False - ) - self._genai_client = config_entry.runtime_data + last_rendered_recommended = False - async def async_step_init( + @property + def _genai_client(self) -> genai.Client: + """Return the Google Generative AI client.""" + return self._get_entry().runtime_data + + @property + def _is_new(self) -> bool: + """Return if this is a new subentry.""" + return self.source == "user" + + async def async_step_set_options( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Manage the options.""" - options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options + ) -> SubentryFlowResult: + """Set conversation options.""" + # abort if entry is not loaded + if self._get_entry().state != ConfigEntryState.LOADED: + return self.async_abort(reason="entry_not_loaded") + errors: dict[str, str] = {} - if user_input is not None: + if user_input is None: + if self._is_new: + options = RECOMMENDED_OPTIONS.copy() + else: + # If this is a reconfiguration, we need to copy the existing options + # so that we can show the current values in the form. + options = self._get_reconfigure_subentry().data.copy() + + self.last_rendered_recommended = cast( + bool, options.get(CONF_RECOMMENDED, False) + ) + + else: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if not user_input.get(CONF_LLM_HASS_API): user_input.pop(CONF_LLM_HASS_API, None) + # Don't allow to save options that enable the Google Seearch tool with an Assist API if not ( user_input.get(CONF_LLM_HASS_API) and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True ): - # Don't allow to save options that enable the Google Seearch tool with an Assist API - return self.async_create_entry(title="", data=user_input) + if self._is_new: + return self.async_create_entry( + title=user_input.pop(CONF_NAME), + data=user_input, + ) + + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=user_input, + ) errors[CONF_USE_GOOGLE_SEARCH_TOOL] = "invalid_google_search_option" # Re-render the options again, now with the recommended options shown/hidden @@ -199,15 +240,19 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): options = user_input schema = await google_generative_ai_config_option_schema( - self.hass, options, self._genai_client + self.hass, self._is_new, options, self._genai_client ) return self.async_show_form( - step_id="init", data_schema=vol.Schema(schema), errors=errors + step_id="set_options", data_schema=vol.Schema(schema), errors=errors ) + async_step_reconfigure = async_step_set_options + async_step_user = async_step_set_options + async def google_generative_ai_config_option_schema( hass: HomeAssistant, + is_new: bool, options: Mapping[str, Any], genai_client: genai.Client, ) -> dict: @@ -224,23 +269,32 @@ async def google_generative_ai_config_option_schema( ): suggested_llm_apis = [suggested_llm_apis] - schema = { - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT - ) - }, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": suggested_llm_apis}, - ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), - vol.Required( - CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) - ): bool, - } + if is_new: + schema: dict[vol.Required | vol.Optional, Any] = { + vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str, + } + else: + schema = {} + + schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": suggested_llm_apis}, + ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), + vol.Required( + CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) + ): bool, + } + ) if options.get(CONF_RECOMMENDED): return schema diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 7e699d7c8c0..0735e9015c2 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -6,6 +6,8 @@ DOMAIN = "google_generative_ai_conversation" LOGGER = logging.getLogger(__package__) CONF_PROMPT = "prompt" +DEFAULT_CONVERSATION_NAME = "Google AI Conversation" + ATTR_MODEL = "model" CONF_RECOMMENDED = "recommended" CONF_CHAT_MODEL = "chat_model" diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 00199f5fe1f..d8eae3f6d0d 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import Literal from homeassistant.components import assist_pipeline, conversation -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -22,8 +22,14 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = GoogleGenerativeAIConversationEntity(config_entry) - async_add_entities([agent]) + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "conversation": + continue + + async_add_entities( + [GoogleGenerativeAIConversationEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) class GoogleGenerativeAIConversationEntity( @@ -35,10 +41,10 @@ class GoogleGenerativeAIConversationEntity( _attr_supports_streaming = True - def __init__(self, entry: ConfigEntry) -> None: + def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" - super().__init__(entry) - if self.entry.options.get(CONF_LLM_HASS_API): + super().__init__(entry, subentry) + if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) @@ -70,7 +76,7 @@ class GoogleGenerativeAIConversationEntity( chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Call the API.""" - options = self.entry.options + options = self.subentry.data try: await chat_log.async_provide_llm_data( diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py index 7eef3dbacff..d4b0ec2bbd0 100644 --- a/homeassistant/components/google_generative_ai_conversation/entity.py +++ b/homeassistant/components/google_generative_ai_conversation/entity.py @@ -24,7 +24,7 @@ from google.genai.types import ( from voluptuous_openapi import convert from homeassistant.components import conversation -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, llm from homeassistant.helpers.entity import Entity @@ -301,17 +301,16 @@ async def _transform_stream( class GoogleGenerativeAILLMBaseEntity(Entity): """Google Generative AI base entity.""" - _attr_has_entity_name = True - _attr_name = None - - def __init__(self, entry: ConfigEntry) -> None: + def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" self.entry = entry + self.subentry = subentry + self._attr_name = subentry.title self._genai_client = entry.runtime_data - self._attr_unique_id = entry.entry_id + self._attr_unique_id = subentry.subentry_id self._attr_device_info = dr.DeviceInfo( - identifiers={(DOMAIN, entry.entry_id)}, - name=entry.title, + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, manufacturer="Google", model="Generative AI", entry_type=dr.DeviceEntryType.SERVICE, @@ -322,7 +321,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): chat_log: conversation.ChatLog, ) -> None: """Generate an answer for the chat log.""" - options = self.entry.options + options = self.subentry.data tools: list[Tool | Callable[..., Any]] | None = None if chat_log.llm_api: diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index a57e2f78f53..e523aecbaec 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -18,35 +18,49 @@ "unknown": "[%key:common::config_flow::error::unknown%]" }, "abort": { - "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" } }, - "options": { - "step": { - "init": { - "data": { - "recommended": "Recommended model settings", - "prompt": "Instructions", - "chat_model": "[%key:common::generic::model%]", - "temperature": "Temperature", - "top_p": "Top P", - "top_k": "Top K", - "max_tokens": "Maximum tokens to return in response", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", - "harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes", - "hate_block_threshold": "Content that is rude, disrespectful, or profane", - "sexual_block_threshold": "Contains references to sexual acts or other lewd content", - "dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts", - "enable_google_search_tool": "Enable Google Search tool" - }, - "data_description": { - "prompt": "Instruct how the LLM should respond. This can be a template.", - "enable_google_search_tool": "Only works if there is nothing selected in the \"Control Home Assistant\" setting. See docs for a workaround using it with \"Assist\"." + "config_subentries": { + "conversation": { + "initiate_flow": { + "user": "Add conversation agent", + "reconfigure": "Reconfigure conversation agent" + }, + "entry_type": "Conversation agent", + + "step": { + "set_options": { + "data": { + "name": "[%key:common::config_flow::data::name%]", + "recommended": "Recommended model settings", + "prompt": "Instructions", + "chat_model": "[%key:common::generic::model%]", + "temperature": "Temperature", + "top_p": "Top P", + "top_k": "Top K", + "max_tokens": "Maximum tokens to return in response", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes", + "hate_block_threshold": "Content that is rude, disrespectful, or profane", + "sexual_block_threshold": "Contains references to sexual acts or other lewd content", + "dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts", + "enable_google_search_tool": "Enable Google Search tool" + }, + "data_description": { + "prompt": "Instruct how the LLM should respond. This can be a template.", + "enable_google_search_tool": "Only works if there is nothing selected in the \"Control Home Assistant\" setting. See docs for a workaround using it with \"Assist\"." + } } + }, + "abort": { + "entry_not_loaded": "Cannot add things while the configuration is disabled.", + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]" + }, + "error": { + "invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting." } - }, - "error": { - "invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting." } }, "services": { diff --git a/homeassistant/components/google_generative_ai_conversation/tts.py b/homeassistant/components/google_generative_ai_conversation/tts.py index 160048e4897..50baec67db2 100644 --- a/homeassistant/components/google_generative_ai_conversation/tts.py +++ b/homeassistant/components/google_generative_ai_conversation/tts.py @@ -113,7 +113,6 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity): self._attr_unique_id = f"{entry.entry_id}_tts" self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, - name=entry.title, manufacturer="Google", model="Generative AI", entry_type=dr.DeviceEntryType.SERVICE, diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index f499f18bc15..36d99cd2764 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -5,8 +5,9 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from homeassistant.components.google_generative_ai_conversation.entity import ( +from homeassistant.components.google_generative_ai_conversation.const import ( CONF_USE_GOOGLE_SEARCH_TOOL, + DEFAULT_CONVERSATION_NAME, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API @@ -26,6 +27,15 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: data={ "api_key": "bla", }, + version=2, + subentries_data=[ + { + "data": {}, + "subentry_type": "conversation", + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ], ) entry.runtime_data = Mock() entry.add_to_hass(hass) @@ -38,8 +48,10 @@ async def mock_config_entry_with_assist( ) -> MockConfigEntry: """Mock a config entry with assist.""" with patch("google.genai.models.AsyncModels.get"): - hass.config_entries.async_update_entry( - mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + hass.config_entries.async_update_subentry( + mock_config_entry, + next(iter(mock_config_entry.subentries.values())), + data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, ) await hass.async_block_till_done() return mock_config_entry @@ -51,9 +63,10 @@ async def mock_config_entry_with_google_search( ) -> MockConfigEntry: """Mock a config entry with assist.""" with patch("google.genai.models.AsyncModels.get"): - hass.config_entries.async_update_entry( + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + next(iter(mock_config_entry.subentries.values())), + data={ CONF_LLM_HASS_API: llm.LLM_API_ASSIST, CONF_USE_GOOGLE_SEARCH_TOOL: True, }, diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 0dc0996ad30..e02d85e41c4 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -22,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import ( CONF_TOP_K, CONF_TOP_P, CONF_USE_GOOGLE_SEARCH_TOOL, + DEFAULT_CONVERSATION_NAME, DOMAIN, RECOMMENDED_CHAT_MODEL, RECOMMENDED_HARM_BLOCK_THRESHOLD, @@ -30,7 +31,7 @@ from homeassistant.components.google_generative_ai_conversation.const import ( RECOMMENDED_TOP_P, RECOMMENDED_USE_GOOGLE_SEARCH_TOOL, ) -from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -110,10 +111,100 @@ async def test_form(hass: HomeAssistant) -> None: assert result2["data"] == { "api_key": "bla", } - assert result2["options"] == RECOMMENDED_OPTIONS + assert result2["options"] == {} + assert result2["subentries"] == [ + { + "subentry_type": "conversation", + "data": RECOMMENDED_OPTIONS, + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ] assert len(mock_setup_entry.mock_calls) == 1 +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Test we get the form.""" + MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "bla"}, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_API_KEY: "bla", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + +async def test_creating_conversation_subentry( + hass: HomeAssistant, + mock_init_component: None, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry.""" + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "set_options" + assert not result["errors"] + + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + result2 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + {CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS}, + ) + await hass.async_block_till_done() + + assert result2["type"] is FlowResultType.CREATE_ENTRY + assert result2["title"] == "Mock name" + + processed_options = RECOMMENDED_OPTIONS.copy() + processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip() + + assert result2["data"] == processed_options + + +async def test_creating_conversation_subentry_not_loaded( + hass: HomeAssistant, + mock_init_component: None, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry.""" + await hass.config_entries.async_unload(mock_config_entry.entry_id) + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "entry_not_loaded" + + def will_options_be_rendered_again(current_options, new_options) -> bool: """Determine if options will be rendered again.""" return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED) @@ -283,7 +374,7 @@ def will_options_be_rendered_again(current_options, new_options) -> bool: ], ) @pytest.mark.usefixtures("mock_init_component") -async def test_options_switching( +async def test_subentry_options_switching( hass: HomeAssistant, mock_config_entry: MockConfigEntry, current_options, @@ -292,17 +383,18 @@ async def test_options_switching( errors, ) -> None: """Test the options form.""" + subentry = next(iter(mock_config_entry.subentries.values())) with patch("google.genai.models.AsyncModels.get"): - hass.config_entries.async_update_entry( - mock_config_entry, options=current_options + hass.config_entries.async_update_subentry( + mock_config_entry, subentry, data=current_options ) await hass.async_block_till_done() with patch( "google.genai.models.AsyncModels.list", return_value=get_models_pager(), ): - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id + options_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id ) if will_options_be_rendered_again(current_options, new_options): retry_options = { @@ -313,7 +405,7 @@ async def test_options_switching( "google.genai.models.AsyncModels.list", return_value=get_models_pager(), ): - options_flow = await hass.config_entries.options.async_configure( + options_flow = await hass.config_entries.subentries.async_configure( options_flow["flow_id"], retry_options, ) @@ -321,14 +413,15 @@ async def test_options_switching( "google.genai.models.AsyncModels.list", return_value=get_models_pager(), ): - options = await hass.config_entries.options.async_configure( + options = await hass.config_entries.subentries.async_configure( options_flow["flow_id"], new_options, ) - await hass.async_block_till_done() + await hass.async_block_till_done() if errors is None: - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == expected_options + assert options["type"] is FlowResultType.ABORT + assert options["reason"] == "reconfigure_successful" + assert subentry.data == expected_options else: assert options["type"] is FlowResultType.FORM @@ -375,7 +468,10 @@ async def test_reauth_flow(hass: HomeAssistant) -> None: """Test the reauth flow.""" hass.config.components.add("google_generative_ai_conversation") mock_config_entry = MockConfigEntry( - domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini" + domain=DOMAIN, + state=config_entries.ConfigEntryState.LOADED, + title="Gemini", + version=2, ) mock_config_entry.add_to_hass(hass) mock_config_entry.async_start_reauth(hass) diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 92aa6f08d42..ff9694257f9 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -64,7 +64,7 @@ async def test_error_handling( "hello", None, Context(), - agent_id="conversation.google_generative_ai_conversation", + agent_id="conversation.google_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result @@ -82,7 +82,7 @@ async def test_function_call( mock_send_message_stream: AsyncMock, ) -> None: """Test function calling.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent( mock_send_message_stream: AsyncMock, ) -> None: """Test if the Google Search tool is sent to the model.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -278,7 +278,7 @@ async def test_blocked_response( mock_send_message_stream: AsyncMock, ) -> None: """Test blocked response.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -328,7 +328,7 @@ async def test_empty_response( ) -> None: """Test empty response.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -371,7 +371,7 @@ async def test_none_response( mock_send_message_stream: AsyncMock, ) -> None: """Test None response.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -403,10 +403,12 @@ async def test_converse_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test handling ChatLog raising ConverseError.""" + subentry = next(iter(mock_config_entry.subentries.values())) with patch("google.genai.models.AsyncModels.get"): - hass.config_entries.async_update_entry( + hass.config_entries.async_update_subentry( mock_config_entry, - options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, + next(iter(mock_config_entry.subentries.values())), + data={**subentry.data, CONF_LLM_HASS_API: "invalid_llm_api"}, ) await hass.async_block_till_done() @@ -415,7 +417,7 @@ async def test_converse_error( "hello", None, Context(), - agent_id="conversation.google_generative_ai_conversation", + agent_id="conversation.google_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -593,7 +595,7 @@ async def test_empty_content_in_chat_history( mock_send_message_stream: AsyncMock, ) -> None: """Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -648,7 +650,7 @@ async def test_history_always_user_first_turn( ) -> None: """Test that the user is always first in the chat history.""" - agent_id = "conversation.google_generative_ai_conversation" + agent_id = "conversation.google_ai_conversation" context = Context() messages = [ @@ -674,7 +676,7 @@ async def test_history_always_user_first_turn( mock_chat_log.async_add_assistant_content_without_tools( conversation.AssistantContent( - agent_id="conversation.google_generative_ai_conversation", + agent_id="conversation.google_ai_conversation", content="Garage door left open, do you want to close it?", ) ) diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 6cc0bdd5f44..dc42232fa65 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -7,9 +7,12 @@ import pytest from requests.exceptions import Timeout from syrupy.assertion import SnapshotAssertion +from homeassistant.components.google_generative_ai_conversation.const import DOMAIN from homeassistant.config_entries import ConfigEntryState +from homeassistant.const import CONF_API_KEY from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, entity_registry as er from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID @@ -387,3 +390,214 @@ async def test_load_entry_with_unloaded_entries( "text": stubbed_generated_content, } assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot + + +async def test_migration_from_v1_to_v2( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2.""" + # Create a v1 config entry with conversation options and an entity + options = { + "recommended": True, + "llm_hass_api": ["assist"], + "prompt": "You are a helpful assistant", + "chat_model": "models/gemini-2.0-flash", + } + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "1234"}, + options=options, + version=1, + title="Google Generative AI", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "1234"}, + options=options, + version=1, + title="Google Generative AI 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device_1 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Google", + model="Generative AI", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device_1.id, + suggested_object_id="google_generative_ai_conversation", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Google", + model="Generative AI", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="google_generative_ai_conversation_2", + ) + + # Run migration + with patch( + "homeassistant.components.google_generative_ai_conversation.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 2 + for subentry in entry.subentries.values(): + assert subentry.subentry_type == "conversation" + assert subentry.data == options + assert "Google Generative AI" in subentry.title + + subentry = list(entry.subentries.values())[0] + + entity = entity_registry.async_get("conversation.google_generative_ai_conversation") + assert entity.unique_id == subentry.subentry_id + assert entity.config_subentry_id == subentry.subentry_id + assert entity.config_entry_id == entry.entry_id + + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry.entry_id)} + ) + assert ( + device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert device.id == device_1.id + + subentry = list(entry.subentries.values())[1] + + entity = entity_registry.async_get( + "conversation.google_generative_ai_conversation_2" + ) + assert entity.unique_id == subentry.subentry_id + assert entity.config_subentry_id == subentry.subentry_id + assert entity.config_entry_id == entry.entry_id + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry_2.entry_id)} + ) + assert ( + device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert device.id == device_2.id + + +async def test_migration_from_v1_to_v2_with_multiple_keys( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2.""" + # Create a v1 config entry with conversation options and an entity + options = { + "recommended": True, + "llm_hass_api": ["assist"], + "prompt": "You are a helpful assistant", + "chat_model": "models/gemini-2.0-flash", + } + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "1234"}, + options=options, + version=1, + title="Google Generative AI", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "12345"}, + options=options, + version=1, + title="Google Generative AI 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Google", + model="Generative AI", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="google_generative_ai_conversation", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Google", + model="Generative AI", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="google_generative_ai_conversation_2", + ) + + # Run migration + with patch( + "homeassistant.components.google_generative_ai_conversation.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + for entry in entries: + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 1 + subentry = list(entry.subentries.values())[0] + assert subentry.subentry_type == "conversation" + assert subentry.data == options + assert "Google Generative AI" in subentry.title + + dev = device_registry.async_get_device( + identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} + ) + assert dev is not None diff --git a/tests/components/google_generative_ai_conversation/test_tts.py b/tests/components/google_generative_ai_conversation/test_tts.py index 5ea056307b5..4f197f0535f 100644 --- a/tests/components/google_generative_ai_conversation/test_tts.py +++ b/tests/components/google_generative_ai_conversation/test_tts.py @@ -122,7 +122,9 @@ async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None: async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None: """Mock config entry setup.""" default_config = {tts.CONF_LANG: "en-US"} - config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config) + config_entry = MockConfigEntry( + domain=DOMAIN, data=default_config | config, version=2 + ) client_mock = Mock() client_mock.models.get = None