From 35478e316296f6efd640a42ff5abfcb85edc6342 Mon Sep 17 00:00:00 2001 From: Joost Lekkerkerker Date: Thu, 26 Jun 2025 19:44:15 +0200 Subject: [PATCH] Set Google AI model as device model (#147582) * Set Google AI model as device model * fix --- .../entity.py | 9 ++- .../google_generative_ai_conversation/tts.py | 6 +- .../snapshots/test_init.ambr | 66 +++++++++++++++++++ .../test_init.py | 14 ++++ 4 files changed, 92 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py index 66acb6b158a..dea875212ef 100644 --- a/homeassistant/components/google_generative_ai_conversation/entity.py +++ b/homeassistant/components/google_generative_ai_conversation/entity.py @@ -301,7 +301,12 @@ async def _transform_stream( class GoogleGenerativeAILLMBaseEntity(Entity): """Google Generative AI base entity.""" - def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None: + def __init__( + self, + entry: ConfigEntry, + subentry: ConfigSubentry, + default_model: str = RECOMMENDED_CHAT_MODEL, + ) -> None: """Initialize the agent.""" self.entry = entry self.subentry = subentry @@ -312,7 +317,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): identifiers={(DOMAIN, subentry.subentry_id)}, name=subentry.title, manufacturer="Google", - model="Generative AI", + model=subentry.data.get(CONF_CHAT_MODEL, default_model).split("/")[-1], entry_type=dr.DeviceEntryType.SERVICE, ) diff --git a/homeassistant/components/google_generative_ai_conversation/tts.py b/homeassistant/components/google_generative_ai_conversation/tts.py index 9bd7d547100..9bc5b0c6cb6 100644 --- a/homeassistant/components/google_generative_ai_conversation/tts.py +++ b/homeassistant/components/google_generative_ai_conversation/tts.py @@ -15,7 +15,7 @@ from homeassistant.components.tts import ( TtsAudioType, Voice, ) -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback @@ -114,6 +114,10 @@ class GoogleGenerativeAITextToSpeechEntity( ) ] + def __init__(self, config_entry: ConfigEntry, subentry: ConfigSubentry) -> None: + """Initialize the TTS entity.""" + super().__init__(config_entry, subentry, RECOMMENDED_TTS_MODEL) + @callback def async_get_supported_voices(self, language: str) -> list[Voice]: """Return a list of supported voices for a language.""" diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr index f89871ff131..5722713bc56 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -1,4 +1,70 @@ # serializer version: 1 +# name: test_devices + list([ + DeviceRegistryEntrySnapshot({ + 'area_id': None, + 'config_entries': , + 'config_entries_subentries': , + 'configuration_url': None, + 'connections': set({ + }), + 'disabled_by': None, + 'entry_type': , + 'hw_version': None, + 'id': , + 'identifiers': set({ + tuple( + 'google_generative_ai_conversation', + 'ulid-conversation', + ), + }), + 'is_new': False, + 'labels': set({ + }), + 'manufacturer': 'Google', + 'model': 'gemini-2.5-flash', + 'model_id': None, + 'name': 'Google AI Conversation', + 'name_by_user': None, + 'primary_config_entry': , + 'serial_number': None, + 'suggested_area': None, + 'sw_version': None, + 'via_device_id': None, + }), + DeviceRegistryEntrySnapshot({ + 'area_id': None, + 'config_entries': , + 'config_entries_subentries': , + 'configuration_url': None, + 'connections': set({ + }), + 'disabled_by': None, + 'entry_type': , + 'hw_version': None, + 'id': , + 'identifiers': set({ + tuple( + 'google_generative_ai_conversation', + 'ulid-tts', + ), + }), + 'is_new': False, + 'labels': set({ + }), + 'manufacturer': 'Google', + 'model': 'gemini-2.5-flash-preview-tts', + 'model_id': None, + 'name': 'Google AI TTS', + 'name_by_user': None, + 'primary_config_entry': , + 'serial_number': None, + 'suggested_area': None, + 'sw_version': None, + 'via_device_id': None, + }), + ]) +# --- # name: test_generate_content_file_processing_succeeds list([ tuple( diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 46a2d634b81..85d6c70b658 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -762,3 +762,17 @@ async def test_migration_from_v1_to_v2_with_same_keys( ) assert device.identifiers == {(DOMAIN, subentry.subentry_id)} assert device.id == device_2.id + + +async def test_devices( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Assert that devices are created correctly.""" + devices = dr.async_entries_for_config_entry( + device_registry, mock_config_entry.entry_id + ) + assert devices == snapshot