Fixes in Google AI TTS (#147501)

* Fix Google AI not using correct config options after subentries migration

* Fixes in Google AI TTS

* Fix tests by @IvanLH

* Change type name.

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Authored by tronikos on 2025-06-25 15:12:23 -07:00; committed by GitHub
parent 345ec97dd5
commit f0a78aadbe
10 changed files with 412 additions and 363 deletions
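
At its core, the fix makes the TTS entity read its options from its own "tts" subentry (which the migration and config flow now create) instead of falling back to hardcoded defaults. A condensed before/after of the model lookup, taken from the TTS entity diff below:

    # before: model taken from the service-call options or a hardcoded default
    model = options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)

    # after: model read from the TTS subentry's stored options
    model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_TTS_MODEL)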

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import mimetypes
from pathlib import Path
from types import MappingProxyType
from google.genai import Client
from google.genai.errors import APIError, ClientError
@ -36,10 +37,12 @@ from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_PROMPT,
DEFAULT_TTS_NAME,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS,
)
@ -242,6 +245,16 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
parent_entry = api_keys_entries[entry.data[CONF_API_KEY]]
hass.config_entries.async_add_subentry(parent_entry, subentry)
if use_existing:
hass.config_entries.async_add_subentry(
parent_entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_TTS_OPTIONS),
subentry_type="tts",
title=DEFAULT_TTS_NAME,
unique_id=None,
),
)
conversation_entity = entity_registry.async_get_entity_id(
"conversation",
DOMAIN,

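ConfigSubentry stores its data as a read-only mapping, which is presumably why the migration wraps RECOMMENDED_TTS_OPTIONS in MappingProxyType before handing it to async_add_subentry. A minimal standard-library sketch of what that wrapper provides:

    from types import MappingProxyType

    recommended_tts_options = {"recommended": True}   # mirrors RECOMMENDED_TTS_OPTIONS from const.py
    data = MappingProxyType(recommended_tts_options)

    print(data["recommended"])      # reads behave like a normal dict -> True
    # data["recommended"] = False   # any write raises TypeError: the proxy is read-only
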
View File

@ -47,13 +47,17 @@ from .const import (
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
RECOMMENDED_TTS_MODEL,
RECOMMENDED_TTS_OPTIONS,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
TIMEOUT_MILLIS,
)
@ -66,12 +70,6 @@ STEP_API_DATA_SCHEMA = vol.Schema(
}
)
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
async def validate_input(data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@ -123,10 +121,16 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
subentries=[
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "tts",
"data": RECOMMENDED_TTS_OPTIONS,
"title": DEFAULT_TTS_NAME,
"unique_id": None,
},
],
)
return self.async_show_form(
@ -172,10 +176,13 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
return {
"conversation": LLMSubentryFlowHandler,
"tts": LLMSubentryFlowHandler,
}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
class LLMSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
last_rendered_recommended = False
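
Both subentry types are now served by a single flow handler, which tailors its defaults and schema to whichever type launched it via the `_subentry_type` property added to ConfigSubentryFlow later in this commit. A condensed sketch of that dispatch, assembled from the hunks in this file:

    # async_get_supported_subentry_types(): one handler class for both types
    return {
        "conversation": LLMSubentryFlowHandler,
        "tts": LLMSubentryFlowHandler,
    }

    # inside LLMSubentryFlowHandler: pick the matching recommended defaults
    if self._subentry_type == "tts":
        options = RECOMMENDED_TTS_OPTIONS.copy()
    else:
        options = RECOMMENDED_CONVERSATION_OPTIONS.copy()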
@ -202,7 +209,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if user_input is None:
if self._is_new:
options = RECOMMENDED_OPTIONS.copy()
options: dict[str, Any]
if self._subentry_type == "tts":
options = RECOMMENDED_TTS_OPTIONS.copy()
else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
else:
# If this is a reconfiguration, we need to copy the existing options
# so that we can show the current values in the form.
@ -216,7 +227,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
# Don't allow to save options that enable the Google Seearch tool with an Assist API
# Don't allow to save options that enable the Google Search tool with an Assist API
if not (
user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
@ -240,7 +251,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
options = user_input
schema = await google_generative_ai_config_option_schema(
self.hass, self._is_new, options, self._genai_client
self.hass, self._is_new, self._subentry_type, options, self._genai_client
)
return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
@ -253,6 +264,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
is_new: bool,
subentry_type: str,
options: Mapping[str, Any],
genai_client: genai.Client,
) -> dict:
@ -270,26 +282,39 @@ async def google_generative_ai_config_option_schema(
suggested_llm_apis = [suggested_llm_apis]
if is_new:
if CONF_NAME in options:
default_name = options[CONF_NAME]
elif subentry_type == "tts":
default_name = DEFAULT_TTS_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = {
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
vol.Required(CONF_NAME, default=default_name): str,
}
else:
schema = {}
if subentry_type == "conversation":
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": suggested_llm_apis},
): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
}
)
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": suggested_llm_apis},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
@ -310,7 +335,7 @@ async def google_generative_ai_config_option_schema(
if (
api_model.display_name
and api_model.name
and "tts" not in api_model.name
and ("tts" in api_model.name) == (subentry_type == "tts")
and "vision" not in api_model.name
and api_model.supported_actions
and "generateContent" in api_model.supported_actions
@ -341,12 +366,17 @@ async def google_generative_ai_config_option_schema(
)
)
if subentry_type == "tts":
default_model = RECOMMENDED_TTS_MODEL
else:
default_model = RECOMMENDED_CHAT_MODEL
schema.update(
{
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
default=default_model,
): SelectSelector(
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
),
@ -396,13 +426,18 @@ async def google_generative_ai_config_option_schema(
},
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
): harm_block_thresholds_selector,
vol.Optional(
CONF_USE_GOOGLE_SEARCH_TOOL,
description={
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
},
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
): bool,
}
)
if subentry_type == "conversation":
schema.update(
{
vol.Optional(
CONF_USE_GOOGLE_SEARCH_TOOL,
description={
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
},
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
): bool,
}
)
return schema
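
In practice, a "tts" subentry therefore gets a schema without the prompt, the "Control Home Assistant" LLM API selector, or the Google Search toggle, while the shared name, recommended, model and tuning fields remain (matching the keys listed for the tts subentry in strings.json below). A rough usage sketch, with `client` standing in for the entry's genai.Client:

    options = {**RECOMMENDED_TTS_OPTIONS, CONF_RECOMMENDED: False}  # show the full, non-recommended form
    schema = await google_generative_ai_config_option_schema(
        hass,
        True,      # is_new
        "tts",     # subentry_type
        options,
        client,    # genai_client
    )
    # roughly: name, recommended, chat_model, temperature, top_p, top_k,
    # max_tokens and the four harm-block thresholds; prompt, LLM API and
    # Google Search fields are only added for "conversation" subentries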

View File

@ -2,17 +2,20 @@
import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_TTS_NAME = "Google AI TTS"
ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
RECOMMENDED_TTS_MODEL = "gemini-2.5-flash-preview-tts"
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_TOP_P = "top_p"
@ -31,3 +34,12 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_RECOMMENDED: True,
}
RECOMMENDED_TTS_OPTIONS = {
CONF_RECOMMENDED: True,
}
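
These two dictionaries are the default data for the subentries that a fresh config entry (and the v1→v2 migration) now receives. Condensed from the config_flow diff above, the created subentries look like:

    subentries = [
        {
            "subentry_type": "conversation",
            "data": RECOMMENDED_CONVERSATION_OPTIONS,  # prompt + Assist API, recommended
            "title": DEFAULT_CONVERSATION_NAME,
            "unique_id": None,
        },
        {
            "subentry_type": "tts",
            "data": RECOMMENDED_TTS_OPTIONS,  # recommended only
            "title": DEFAULT_TTS_NAME,
            "unique_id": None,
        },
    ]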

View File

@ -29,7 +29,6 @@
"reconfigure": "Reconfigure conversation agent"
},
"entry_type": "Conversation agent",
"step": {
"set_options": {
"data": {
@ -61,6 +60,34 @@
"error": {
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
}
},
"tts": {
"initiate_flow": {
"user": "Add Text-to-Speech service",
"reconfigure": "Reconfigure Text-to-Speech service"
},
"entry_type": "Text-to-Speech",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
}
}
},
"abort": {
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
}
},
"services": {

View File

@ -2,13 +2,15 @@
from __future__ import annotations
from collections.abc import Mapping
from contextlib import suppress
import io
import logging
from typing import Any
import wave
from google.genai import types
from google.genai.errors import APIError, ClientError
from propcache.api import cached_property
from homeassistant.components.tts import (
ATTR_VOICE,
@ -19,12 +21,10 @@ from homeassistant.components.tts import (
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import ATTR_MODEL, DOMAIN, RECOMMENDED_TTS_MODEL
_LOGGER = logging.getLogger(__name__)
from .const import CONF_CHAT_MODEL, LOGGER, RECOMMENDED_TTS_MODEL
from .entity import GoogleGenerativeAILLMBaseEntity
async def async_setup_entry(
@ -32,15 +32,23 @@ async def async_setup_entry(
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up TTS entity."""
tts_entity = GoogleGenerativeAITextToSpeechEntity(config_entry)
async_add_entities([tts_entity])
"""Set up TTS entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "tts":
continue
async_add_entities(
[GoogleGenerativeAITextToSpeechEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
class GoogleGenerativeAITextToSpeechEntity(
TextToSpeechEntity, GoogleGenerativeAILLMBaseEntity
):
"""Google Generative AI text-to-speech entity."""
_attr_supported_options = [ATTR_VOICE, ATTR_MODEL]
_attr_supported_options = [ATTR_VOICE]
# See https://ai.google.dev/gemini-api/docs/speech-generation#languages
_attr_supported_languages = [
"ar-EG",
@ -68,6 +76,8 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
"uk-UA",
"vi-VN",
]
# Unused, but required by base class.
# The Gemini TTS models detect the input language automatically.
_attr_default_language = "en-US"
# See https://ai.google.dev/gemini-api/docs/speech-generation#voices
_supported_voices = [
@ -106,53 +116,41 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
)
]
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize Google Generative AI Conversation speech entity."""
self.entry = entry
self._attr_name = "Google Generative AI TTS"
self._attr_unique_id = f"{entry.entry_id}_tts"
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
manufacturer="Google",
model="Generative AI",
entry_type=dr.DeviceEntryType.SERVICE,
)
self._genai_client = entry.runtime_data
self._default_voice_id = self._supported_voices[0].voice_id
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
def async_get_supported_voices(self, language: str) -> list[Voice]:
"""Return a list of supported voices for a language."""
return self._supported_voices
@cached_property
def default_options(self) -> Mapping[str, Any]:
"""Return a mapping with the default options."""
return {
ATTR_VOICE: self._supported_voices[0].voice_id,
}
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine."""
try:
response = self._genai_client.models.generate_content(
model=options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL),
contents=message,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=options.get(
ATTR_VOICE, self._default_voice_id
)
)
)
),
),
config = self.create_generate_content_config()
config.response_modalities = ["AUDIO"]
config.speech_config = types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=options[ATTR_VOICE]
)
)
)
try:
response = await self._genai_client.aio.models.generate_content(
model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_TTS_MODEL),
contents=message,
config=config,
)
data = response.candidates[0].content.parts[0].inline_data.data
mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
except Exception as exc:
_LOGGER.warning(
"Error during processing of TTS request %s", exc, exc_info=True
)
except (APIError, ClientError, ValueError) as exc:
LOGGER.error("Error during TTS: %s", exc, exc_info=True)
raise HomeAssistantError(exc) from exc
return "wav", self._convert_to_wav(data, mime_type)
@ -192,7 +190,7 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
"""
if not mime_type.startswith("audio/L"):
_LOGGER.warning("Received unexpected MIME type %s", mime_type)
LOGGER.warning("Received unexpected MIME type %s", mime_type)
raise HomeAssistantError(f"Unsupported audio MIME type: {mime_type}")
bits_per_sample = 16

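The Gemini TTS endpoint returns raw 16-bit PCM (the tests use "audio/L16;rate=24000"), and _convert_to_wav wraps those bytes in a WAV container before they are returned to the TTS pipeline. A self-contained sketch of that conversion with the wave module, assuming mono output and a fixed 24 kHz rate (the helper name here is illustrative, and the real method presumably derives the rate from the MIME type):

    import io
    import wave

    def pcm_to_wav(pcm: bytes, sample_rate: int = 24000, bits_per_sample: int = 16) -> bytes:
        """Wrap raw mono L16 PCM samples in a WAV container."""
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wav_file:
            wav_file.setnchannels(1)                     # mono
            wav_file.setsampwidth(bits_per_sample // 8)  # 16 bit -> 2 bytes per sample
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm)
        return buf.getvalue()

    # usage: wav_bytes = pcm_to_wav(raw_audio_bytes)
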
View File

@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
"""Return config entry id."""
return self.handler[0]
@property
def _subentry_type(self) -> str:
"""Return type of subentry we are editing/creating."""
return self.handler[1]
@callback
def _get_entry(self) -> ConfigEntry:
"""Return the config entry linked to the current context."""

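For a subentry flow, `self.handler` is the `(config_entry_id, subentry_type)` pair, which is why `_config_entry_id` returns `handler[0]` and the new `_subentry_type` property returns `handler[1]`. A tiny illustration with a made-up entry id:

    handler = ("1234abcd5678ef90", "tts")    # hypothetical (config_entry_id, subentry_type)
    assert handler[0] == "1234abcd5678ef90"  # -> _config_entry_id
    assert handler[1] == "tts"               # -> _subentry_type
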
View File

@ -8,6 +8,7 @@ import pytest
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TTS_NAME,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
@ -34,7 +35,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"data": {},
"subentry_type": "tts",
"title": DEFAULT_TTS_NAME,
"unique_id": None,
},
],
)
entry.runtime_data = Mock()

View File

@ -6,9 +6,6 @@ import pytest
from requests.exceptions import Timeout
from homeassistant import config_entries
from homeassistant.components.google_generative_ai_conversation.config_flow import (
RECOMMENDED_OPTIONS,
)
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
CONF_DANGEROUS_BLOCK_THRESHOLD,
@ -23,12 +20,15 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
RECOMMENDED_TTS_OPTIONS,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
)
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
@ -115,10 +115,16 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["subentries"] == [
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "tts",
"data": RECOMMENDED_TTS_OPTIONS,
"title": DEFAULT_TTS_NAME,
"unique_id": None,
},
]
assert len(mock_setup_entry.mock_calls) == 1
@ -172,19 +178,64 @@ async def test_creating_conversation_subentry(
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name"
processed_options = RECOMMENDED_OPTIONS.copy()
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options
async def test_creating_tts_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a TTS subentry."""
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "tts"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock TTS"
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
assert len(mock_config_entry.subentries) == 3
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "tts"
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
assert new_subentry.title == "Mock TTS"
async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant,
mock_init_component: None,

View File

@ -7,7 +7,11 @@ import pytest
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_TTS_OPTIONS,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
@ -469,13 +473,27 @@ async def test_migration_from_v1_to_v2(
entry = entries[0]
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 2
for subentry in entry.subentries.values():
assert len(entry.subentries) == 3
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation"
]
assert len(conversation_subentries) == 2
for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == options
assert "Google Generative AI" in subentry.title
tts_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "tts"
]
assert len(tts_subentries) == 1
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
assert tts_subentries[0].title == DEFAULT_TTS_NAME
subentry = list(entry.subentries.values())[0]
subentry = conversation_subentries[0]
entity = entity_registry.async_get("conversation.google_generative_ai_conversation")
assert entity.unique_id == subentry.subentry_id
@ -493,7 +511,7 @@ async def test_migration_from_v1_to_v2(
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
assert device.id == device_1.id
subentry = list(entry.subentries.values())[1]
subentry = conversation_subentries[1]
entity = entity_registry.async_get(
"conversation.google_generative_ai_conversation_2"
@ -591,11 +609,15 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
for entry in entries:
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 1
assert len(entry.subentries) == 2
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == options
assert "Google Generative AI" in subentry.title
subentry = list(entry.subentries.values())[1]
assert subentry.subentry_type == "tts"
assert subentry.data == RECOMMENDED_TTS_OPTIONS
assert subentry.title == DEFAULT_TTS_NAME
dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
@ -680,13 +702,27 @@ async def test_migration_from_v1_to_v2_with_same_keys(
entry = entries[0]
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 2
for subentry in entry.subentries.values():
assert len(entry.subentries) == 3
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation"
]
assert len(conversation_subentries) == 2
for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == options
assert "Google Generative AI" in subentry.title
tts_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "tts"
]
assert len(tts_subentries) == 1
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
assert tts_subentries[0].title == DEFAULT_TTS_NAME
subentry = list(entry.subentries.values())[0]
subentry = conversation_subentries[0]
entity = entity_registry.async_get("conversation.google_generative_ai_conversation")
assert entity.unique_id == subentry.subentry_id
@ -704,7 +740,7 @@ async def test_migration_from_v1_to_v2_with_same_keys(
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
assert device.id == device_1.id
subentry = list(entry.subentries.values())[1]
subentry = conversation_subentries[1]
entity = entity_registry.async_get(
"conversation.google_generative_ai_conversation_2"

View File

@ -9,30 +9,37 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from google.genai import types
from google.genai.errors import APIError
import pytest
from homeassistant.components import tts
from homeassistant.components.google_generative_ai_conversation.tts import (
ATTR_MODEL,
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
DOMAIN,
RECOMMENDED_TTS_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
)
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_PLATFORM
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core_config import async_process_ha_core_config
from homeassistant.setup import async_setup_component
from . import API_ERROR_500
from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
API_ERROR_500 = APIError("test", response=MagicMock())
TEST_CHAT_MODEL = "models/some-tts-model"
@pytest.fixture(autouse=True)
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
@ -63,20 +70,22 @@ def mock_genai_client() -> Generator[AsyncMock]:
"""Mock genai_client."""
client = Mock()
client.aio.models.get = AsyncMock()
client.models.generate_content.return_value = types.GenerateContentResponse(
candidates=(
types.Candidate(
content=types.Content(
parts=(
types.Part(
inline_data=types.Blob(
data=b"raw-audio-bytes",
mime_type="audio/L16;rate=24000",
)
),
client.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=(
types.Candidate(
content=types.Content(
parts=(
types.Part(
inline_data=types.Blob(
data=b"raw-audio-bytes",
mime_type="audio/L16;rate=24000",
)
),
)
)
)
),
),
)
)
)
with patch(
@ -90,17 +99,29 @@ def mock_genai_client() -> Generator[AsyncMock]:
async def setup_fixture(
hass: HomeAssistant,
config: dict[str, Any],
request: pytest.FixtureRequest,
mock_genai_client: AsyncMock,
) -> None:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, config)
if request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, config)
else:
raise RuntimeError("Invalid setup fixture")
config_entry = MockConfigEntry(domain=DOMAIN, data=config, version=2)
config_entry.add_to_hass(hass)
sub_entry = ConfigSubentry(
data={
tts.CONF_LANG: "en-US",
CONF_CHAT_MODEL: TEST_CHAT_MODEL,
},
subentry_type="tts",
title="Google AI TTS",
subentry_id="test_subentry_tts_id",
unique_id=None,
)
config_entry.runtime_data = mock_genai_client
hass.config_entries.async_add_subentry(config_entry, sub_entry)
await hass.config_entries.async_setup(config_entry.entry_id)
assert await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done()
@ -112,105 +133,38 @@ def config_fixture() -> dict[str, Any]:
}
async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock setup."""
assert await async_setup_component(
hass, tts.DOMAIN, {tts.DOMAIN: {CONF_PLATFORM: DOMAIN} | config}
)
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock config entry setup."""
default_config = {tts.CONF_LANG: "en-US"}
config_entry = MockConfigEntry(
domain=DOMAIN, data=default_config | config, version=2
)
client_mock = Mock()
client_mock.models.get = None
client_mock.models.generate_content.return_value = types.GenerateContentResponse(
candidates=(
types.Candidate(
content=types.Content(
parts=(
types.Part(
inline_data=types.Blob(
data=b"raw-audio-bytes",
mime_type="audio/L16;rate=24000",
)
),
)
)
),
)
)
config_entry.runtime_data = client_mock
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
"service_data",
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
},
),
{
ATTR_ENTITY_ID: "tts.google_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {},
},
{
ATTR_ENTITY_ID: "tts.google_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
},
],
indirect=["setup"],
)
@pytest.mark.usefixtures("setup")
async def test_tts_service_speak(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test tts service."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
tts_entity._genai_client.aio.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
"speak",
service_data,
blocking=True,
)
@ -221,10 +175,9 @@ async def test_tts_service_speak(
== HTTPStatus.OK
)
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "zephyr")
model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=model_id,
tts_entity._genai_client.aio.models.generate_content.assert_called_once_with(
model=TEST_CHAT_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
@ -233,109 +186,52 @@ async def test_tts_service_speak(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
)
),
temperature=RECOMMENDED_TEMPERATURE,
top_k=RECOMMENDED_TOP_K,
top_p=RECOMMENDED_TOP_P,
max_output_tokens=RECOMMENDED_MAX_TOKENS,
safety_settings=[
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
],
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_LANGUAGE: "de-DE",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_LANGUAGE: "it-IT",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak_lang_config(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with languages in the config."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
)
),
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
],
indirect=["setup"],
)
@pytest.mark.usefixtures("setup")
async def test_tts_service_speak_error(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with HTTP response 500."""
service_data = {
ATTR_ENTITY_ID: "tts.google_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
}
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
tts_entity._genai_client.models.generate_content.side_effect = API_ERROR_500
tts_entity._genai_client.aio.models.generate_content.reset_mock()
tts_entity._genai_client.aio.models.generate_content.side_effect = API_ERROR_500
await hass.services.async_call(
tts.DOMAIN,
tts_service,
"speak",
service_data,
blocking=True,
)
@ -346,70 +242,39 @@ async def test_tts_service_speak_error(
== HTTPStatus.INTERNAL_SERVER_ERROR
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE)
tts_entity._genai_client.aio.models.generate_content.assert_called_once_with(
model=TEST_CHAT_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
)
),
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak_without_options(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with HTTP response 200."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="zephyr")
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
)
),
temperature=RECOMMENDED_TEMPERATURE,
top_k=RECOMMENDED_TOP_K,
top_p=RECOMMENDED_TOP_P,
max_output_tokens=RECOMMENDED_MAX_TOKENS,
safety_settings=[
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
],
),
)