Fixes in Google AI TTS (#147501)
* Fix Google AI not using correct config options after subentries migration
* Fixes in Google AI TTS
* Fix tests by @IvanLH
* Change type name.

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent 345ec97dd5
commit f0a78aadbe
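At a glance, the change wires the Google AI TTS platform to its own "tts" config subentry: entities are created per subentry and pull the model and generation options from subentry data instead of hard-coded defaults. A minimal sketch of the new setup path, condensed from the tts.py hunks below (an illustration, not a drop-in copy; it assumes the entity class and Home Assistant callbacks shown in this diff):

    async def async_setup_entry(hass, config_entry, async_add_entities):
        """Create one TTS entity for each "tts" subentry of the config entry."""
        for subentry in config_entry.subentries.values():
            if subentry.subentry_type != "tts":
                continue
            # The entity keeps a reference to its subentry, so model/voice options
            # are read from subentry.data rather than module-level defaults.
            async_add_entities(
                [GoogleGenerativeAITextToSpeechEntity(config_entry, subentry)],
                config_subentry_id=subentry.subentry_id,
            )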
homeassistant/components/google_generative_ai_conversation/__init__.py

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import mimetypes
from pathlib import Path
from types import MappingProxyType

from google.genai import Client
from google.genai.errors import APIError, ClientError
@@ -36,10 +37,12 @@ from homeassistant.helpers.typing import ConfigType

from .const import (
    CONF_PROMPT,
    DEFAULT_TTS_NAME,
    DOMAIN,
    FILE_POLLING_INTERVAL_SECONDS,
    LOGGER,
    RECOMMENDED_CHAT_MODEL,
    RECOMMENDED_TTS_OPTIONS,
    TIMEOUT_MILLIS,
)

@@ -242,6 +245,16 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
        parent_entry = api_keys_entries[entry.data[CONF_API_KEY]]

        hass.config_entries.async_add_subentry(parent_entry, subentry)
        if use_existing:
            hass.config_entries.async_add_subentry(
                parent_entry,
                ConfigSubentry(
                    data=MappingProxyType(RECOMMENDED_TTS_OPTIONS),
                    subentry_type="tts",
                    title=DEFAULT_TTS_NAME,
                    unique_id=None,
                ),
            )
        conversation_entity = entity_registry.async_get_entity_id(
            "conversation",
            DOMAIN,
@ -47,13 +47,17 @@ from .const import (
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_MODEL,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
@ -66,12 +70,6 @@ STEP_API_DATA_SCHEMA = vol.Schema(
|
||||
}
|
||||
)
|
||||
|
||||
RECOMMENDED_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
}
|
||||
|
||||
|
||||
async def validate_input(data: dict[str, Any]) -> None:
|
||||
"""Validate the user input allows us to connect.
|
||||
@ -123,10 +121,16 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
subentries=[
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": RECOMMENDED_OPTIONS,
|
||||
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"subentry_type": "tts",
|
||||
"data": RECOMMENDED_TTS_OPTIONS,
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
return self.async_show_form(
|
||||
@ -172,10 +176,13 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {"conversation": ConversationSubentryFlowHandler}
|
||||
return {
|
||||
"conversation": LLMSubentryFlowHandler,
|
||||
"tts": LLMSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
|
||||
last_rendered_recommended = False
|
||||
@ -202,7 +209,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
|
||||
if user_input is None:
|
||||
if self._is_new:
|
||||
options = RECOMMENDED_OPTIONS.copy()
|
||||
options: dict[str, Any]
|
||||
if self._subentry_type == "tts":
|
||||
options = RECOMMENDED_TTS_OPTIONS.copy()
|
||||
else:
|
||||
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
else:
|
||||
# If this is a reconfiguration, we need to copy the existing options
|
||||
# so that we can show the current values in the form.
|
||||
@ -216,7 +227,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||
if not user_input.get(CONF_LLM_HASS_API):
|
||||
user_input.pop(CONF_LLM_HASS_API, None)
|
||||
# Don't allow to save options that enable the Google Seearch tool with an Assist API
|
||||
# Don't allow to save options that enable the Google Search tool with an Assist API
|
||||
if not (
|
||||
user_input.get(CONF_LLM_HASS_API)
|
||||
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
|
||||
@ -240,7 +251,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
options = user_input
|
||||
|
||||
schema = await google_generative_ai_config_option_schema(
|
||||
self.hass, self._is_new, options, self._genai_client
|
||||
self.hass, self._is_new, self._subentry_type, options, self._genai_client
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||
@ -253,6 +264,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
async def google_generative_ai_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
is_new: bool,
|
||||
subentry_type: str,
|
||||
options: Mapping[str, Any],
|
||||
genai_client: genai.Client,
|
||||
) -> dict:
|
||||
@ -270,26 +282,39 @@ async def google_generative_ai_config_option_schema(
|
||||
suggested_llm_apis = [suggested_llm_apis]
|
||||
|
||||
if is_new:
|
||||
if CONF_NAME in options:
|
||||
default_name = options[CONF_NAME]
|
||||
elif subentry_type == "tts":
|
||||
default_name = DEFAULT_TTS_NAME
|
||||
else:
|
||||
default_name = DEFAULT_CONVERSATION_NAME
|
||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
|
||||
vol.Required(CONF_NAME, default=default_name): str,
|
||||
}
|
||||
else:
|
||||
schema = {}
|
||||
|
||||
if subentry_type == "conversation":
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": suggested_llm_apis},
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||
),
|
||||
}
|
||||
)
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": suggested_llm_apis},
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
@ -310,7 +335,7 @@ async def google_generative_ai_config_option_schema(
|
||||
if (
|
||||
api_model.display_name
|
||||
and api_model.name
|
||||
and "tts" not in api_model.name
|
||||
and ("tts" in api_model.name) == (subentry_type == "tts")
|
||||
and "vision" not in api_model.name
|
||||
and api_model.supported_actions
|
||||
and "generateContent" in api_model.supported_actions
|
||||
@ -341,12 +366,17 @@ async def google_generative_ai_config_option_schema(
|
||||
)
|
||||
)
|
||||
|
||||
if subentry_type == "tts":
|
||||
default_model = RECOMMENDED_TTS_MODEL
|
||||
else:
|
||||
default_model = RECOMMENDED_CHAT_MODEL
|
||||
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||
default=RECOMMENDED_CHAT_MODEL,
|
||||
default=default_model,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
|
||||
),
|
||||
@ -396,13 +426,18 @@ async def google_generative_ai_config_option_schema(
|
||||
},
|
||||
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
): harm_block_thresholds_selector,
|
||||
vol.Optional(
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
description={
|
||||
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
|
||||
},
|
||||
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
if subentry_type == "conversation":
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
description={
|
||||
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
|
||||
},
|
||||
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
return schema
|
||||
|
homeassistant/components/google_generative_ai_conversation/const.py

@@ -2,17 +2,20 @@

import logging

from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm

DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"

DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_TTS_NAME = "Google AI TTS"

ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
RECOMMENDED_TTS_MODEL = "gemini-2.5-flash-preview-tts"
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_TOP_P = "top_p"
@@ -31,3 +34,12 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False

TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05
RECOMMENDED_CONVERSATION_OPTIONS = {
    CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
    CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
    CONF_RECOMMENDED: True,
}

RECOMMENDED_TTS_OPTIONS = {
    CONF_RECOMMENDED: True,
}
@ -29,7 +29,6 @@
|
||||
"reconfigure": "Reconfigure conversation agent"
|
||||
},
|
||||
"entry_type": "Conversation agent",
|
||||
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
@ -61,6 +60,34 @@
|
||||
"error": {
|
||||
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"initiate_flow": {
|
||||
"user": "Add Text-to-Speech service",
|
||||
"reconfigure": "Reconfigure Text-to-Speech service"
|
||||
},
|
||||
"entry_type": "Text-to-Speech",
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
|
||||
"chat_model": "[%key:common::generic::model%]",
|
||||
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
|
||||
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
|
||||
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
|
||||
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
|
||||
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
|
||||
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
|
||||
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
|
||||
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
|
@ -2,13 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from contextlib import suppress
|
||||
import io
|
||||
import logging
|
||||
from typing import Any
|
||||
import wave
|
||||
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError, ClientError
|
||||
from propcache.api import cached_property
|
||||
|
||||
from homeassistant.components.tts import (
|
||||
ATTR_VOICE,
|
||||
@ -19,12 +21,10 @@ from homeassistant.components.tts import (
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import ATTR_MODEL, DOMAIN, RECOMMENDED_TTS_MODEL
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
from .const import CONF_CHAT_MODEL, LOGGER, RECOMMENDED_TTS_MODEL
|
||||
from .entity import GoogleGenerativeAILLMBaseEntity
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@ -32,15 +32,23 @@ async def async_setup_entry(
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up TTS entity."""
|
||||
tts_entity = GoogleGenerativeAITextToSpeechEntity(config_entry)
|
||||
async_add_entities([tts_entity])
|
||||
"""Set up TTS entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "tts":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[GoogleGenerativeAITextToSpeechEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
|
||||
class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
|
||||
class GoogleGenerativeAITextToSpeechEntity(
|
||||
TextToSpeechEntity, GoogleGenerativeAILLMBaseEntity
|
||||
):
|
||||
"""Google Generative AI text-to-speech entity."""
|
||||
|
||||
_attr_supported_options = [ATTR_VOICE, ATTR_MODEL]
|
||||
_attr_supported_options = [ATTR_VOICE]
|
||||
# See https://ai.google.dev/gemini-api/docs/speech-generation#languages
|
||||
_attr_supported_languages = [
|
||||
"ar-EG",
|
||||
@ -68,6 +76,8 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
|
||||
"uk-UA",
|
||||
"vi-VN",
|
||||
]
|
||||
# Unused, but required by base class.
|
||||
# The Gemini TTS models detect the input language automatically.
|
||||
_attr_default_language = "en-US"
|
||||
# See https://ai.google.dev/gemini-api/docs/speech-generation#voices
|
||||
_supported_voices = [
|
||||
@ -106,53 +116,41 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
|
||||
)
|
||||
]
|
||||
|
||||
def __init__(self, entry: ConfigEntry) -> None:
|
||||
"""Initialize Google Generative AI Conversation speech entity."""
|
||||
self.entry = entry
|
||||
self._attr_name = "Google Generative AI TTS"
|
||||
self._attr_unique_id = f"{entry.entry_id}_tts"
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, entry.entry_id)},
|
||||
manufacturer="Google",
|
||||
model="Generative AI",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
self._genai_client = entry.runtime_data
|
||||
self._default_voice_id = self._supported_voices[0].voice_id
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice]:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._supported_voices
|
||||
|
||||
@cached_property
|
||||
def default_options(self) -> Mapping[str, Any]:
|
||||
"""Return a mapping with the default options."""
|
||||
return {
|
||||
ATTR_VOICE: self._supported_voices[0].voice_id,
|
||||
}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine."""
|
||||
try:
|
||||
response = self._genai_client.models.generate_content(
|
||||
model=options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL),
|
||||
contents=message,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name=options.get(
|
||||
ATTR_VOICE, self._default_voice_id
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
config = self.create_generate_content_config()
|
||||
config.response_modalities = ["AUDIO"]
|
||||
config.speech_config = types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name=options[ATTR_VOICE]
|
||||
)
|
||||
)
|
||||
)
|
||||
try:
|
||||
response = await self._genai_client.aio.models.generate_content(
|
||||
model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_TTS_MODEL),
|
||||
contents=message,
|
||||
config=config,
|
||||
)
|
||||
|
||||
data = response.candidates[0].content.parts[0].inline_data.data
|
||||
mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
|
||||
except Exception as exc:
|
||||
_LOGGER.warning(
|
||||
"Error during processing of TTS request %s", exc, exc_info=True
|
||||
)
|
||||
except (APIError, ClientError, ValueError) as exc:
|
||||
LOGGER.error("Error during TTS: %s", exc, exc_info=True)
|
||||
raise HomeAssistantError(exc) from exc
|
||||
return "wav", self._convert_to_wav(data, mime_type)
|
||||
|
||||
@ -192,7 +190,7 @@ class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
|
||||
|
||||
"""
|
||||
if not mime_type.startswith("audio/L"):
|
||||
_LOGGER.warning("Received unexpected MIME type %s", mime_type)
|
||||
LOGGER.warning("Received unexpected MIME type %s", mime_type)
|
||||
raise HomeAssistantError(f"Unsupported audio MIME type: {mime_type}")
|
||||
|
||||
bits_per_sample = 16
|
||||
|
homeassistant/config_entries.py

@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
        """Return config entry id."""
        return self.handler[0]

    @property
    def _subentry_type(self) -> str:
        """Return type of subentry we are editing/creating."""
        return self.handler[1]

    @callback
    def _get_entry(self) -> ConfigEntry:
        """Return the config entry linked to the current context."""
tests/components/google_generative_ai_conversation/conftest.py

@@ -8,6 +8,7 @@ import pytest
from homeassistant.components.google_generative_ai_conversation.const import (
    CONF_USE_GOOGLE_SEARCH_TOOL,
    DEFAULT_CONVERSATION_NAME,
    DEFAULT_TTS_NAME,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
@@ -34,7 +35,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
                "subentry_type": "conversation",
                "title": DEFAULT_CONVERSATION_NAME,
                "unique_id": None,
            }
            },
            {
                "data": {},
                "subentry_type": "tts",
                "title": DEFAULT_TTS_NAME,
                "unique_id": None,
            },
        ],
    )
    entry.runtime_data = Mock()
@ -6,9 +6,6 @@ import pytest
|
||||
from requests.exceptions import Timeout
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.google_generative_ai_conversation.config_flow import (
|
||||
RECOMMENDED_OPTIONS,
|
||||
)
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD,
|
||||
@ -23,12 +20,15 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
|
||||
@ -115,10 +115,16 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
assert result2["subentries"] == [
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": RECOMMENDED_OPTIONS,
|
||||
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"subentry_type": "tts",
|
||||
"data": RECOMMENDED_TTS_OPTIONS,
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
]
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@ -172,19 +178,64 @@ async def test_creating_conversation_subentry(
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
|
||||
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock name"
|
||||
|
||||
processed_options = RECOMMENDED_OPTIONS.copy()
|
||||
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
||||
|
||||
assert result2["data"] == processed_options
|
||||
|
||||
|
||||
async def test_creating_tts_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a TTS subentry."""
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "tts"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM, result
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock TTS"
|
||||
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
|
||||
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == "tts"
|
||||
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
|
||||
assert new_subentry.title == "Mock TTS"
|
||||
|
||||
|
||||
async def test_creating_conversation_subentry_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
|
@ -7,7 +7,11 @@ import pytest
|
||||
from requests.exceptions import Timeout
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -469,13 +473,27 @@ async def test_migration_from_v1_to_v2(
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 2
|
||||
for subentry in entry.subentries.values():
|
||||
assert len(entry.subentries) == 3
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "conversation"
|
||||
]
|
||||
assert len(conversation_subentries) == 2
|
||||
for subentry in conversation_subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
assert "Google Generative AI" in subentry.title
|
||||
tts_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "tts"
|
||||
]
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
entity = entity_registry.async_get("conversation.google_generative_ai_conversation")
|
||||
assert entity.unique_id == subentry.subentry_id
|
||||
@ -493,7 +511,7 @@ async def test_migration_from_v1_to_v2(
|
||||
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
|
||||
assert device.id == device_1.id
|
||||
|
||||
subentry = list(entry.subentries.values())[1]
|
||||
subentry = conversation_subentries[1]
|
||||
|
||||
entity = entity_registry.async_get(
|
||||
"conversation.google_generative_ai_conversation_2"
|
||||
@ -591,11 +609,15 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
|
||||
for entry in entries:
|
||||
assert entry.version == 2
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 1
|
||||
assert len(entry.subentries) == 2
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
assert "Google Generative AI" in subentry.title
|
||||
subentry = list(entry.subentries.values())[1]
|
||||
assert subentry.subentry_type == "tts"
|
||||
assert subentry.data == RECOMMENDED_TTS_OPTIONS
|
||||
assert subentry.title == DEFAULT_TTS_NAME
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||
@ -680,13 +702,27 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 2
|
||||
for subentry in entry.subentries.values():
|
||||
assert len(entry.subentries) == 3
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "conversation"
|
||||
]
|
||||
assert len(conversation_subentries) == 2
|
||||
for subentry in conversation_subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
assert "Google Generative AI" in subentry.title
|
||||
tts_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "tts"
|
||||
]
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
entity = entity_registry.async_get("conversation.google_generative_ai_conversation")
|
||||
assert entity.unique_id == subentry.subentry_id
|
||||
@ -704,7 +740,7 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
||||
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
|
||||
assert device.id == device_1.id
|
||||
|
||||
subentry = list(entry.subentries.values())[1]
|
||||
subentry = conversation_subentries[1]
|
||||
|
||||
entity = entity_registry.async_get(
|
||||
"conversation.google_generative_ai_conversation_2"
|
||||
|
@ -9,30 +9,37 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.google_generative_ai_conversation.tts import (
|
||||
ATTR_MODEL,
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
DOMAIN,
|
||||
RECOMMENDED_TTS_MODEL,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_PLATFORM
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.core_config import async_process_ha_core_config
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import API_ERROR_500
|
||||
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
API_ERROR_500 = APIError("test", response=MagicMock())
|
||||
TEST_CHAT_MODEL = "models/some-tts-model"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
|
||||
@ -63,20 +70,22 @@ def mock_genai_client() -> Generator[AsyncMock]:
|
||||
"""Mock genai_client."""
|
||||
client = Mock()
|
||||
client.aio.models.get = AsyncMock()
|
||||
client.models.generate_content.return_value = types.GenerateContentResponse(
|
||||
candidates=(
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
parts=(
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
data=b"raw-audio-bytes",
|
||||
mime_type="audio/L16;rate=24000",
|
||||
)
|
||||
),
|
||||
client.aio.models.generate_content = AsyncMock(
|
||||
return_value=types.GenerateContentResponse(
|
||||
candidates=(
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
parts=(
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
data=b"raw-audio-bytes",
|
||||
mime_type="audio/L16;rate=24000",
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
with patch(
|
||||
@ -90,17 +99,29 @@ def mock_genai_client() -> Generator[AsyncMock]:
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
config: dict[str, Any],
|
||||
request: pytest.FixtureRequest,
|
||||
mock_genai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Set up the test environment."""
|
||||
if request.param == "mock_setup":
|
||||
await mock_setup(hass, config)
|
||||
if request.param == "mock_config_entry_setup":
|
||||
await mock_config_entry_setup(hass, config)
|
||||
else:
|
||||
raise RuntimeError("Invalid setup fixture")
|
||||
config_entry = MockConfigEntry(domain=DOMAIN, data=config, version=2)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
sub_entry = ConfigSubentry(
|
||||
data={
|
||||
tts.CONF_LANG: "en-US",
|
||||
CONF_CHAT_MODEL: TEST_CHAT_MODEL,
|
||||
},
|
||||
subentry_type="tts",
|
||||
title="Google AI TTS",
|
||||
subentry_id="test_subentry_tts_id",
|
||||
unique_id=None,
|
||||
)
|
||||
|
||||
config_entry.runtime_data = mock_genai_client
|
||||
|
||||
hass.config_entries.async_add_subentry(config_entry, sub_entry)
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@ -112,105 +133,38 @@ def config_fixture() -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Mock setup."""
|
||||
assert await async_setup_component(
|
||||
hass, tts.DOMAIN, {tts.DOMAIN: {CONF_PLATFORM: DOMAIN} | config}
|
||||
)
|
||||
|
||||
|
||||
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Mock config entry setup."""
|
||||
default_config = {tts.CONF_LANG: "en-US"}
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN, data=default_config | config, version=2
|
||||
)
|
||||
|
||||
client_mock = Mock()
|
||||
client_mock.models.get = None
|
||||
client_mock.models.generate_content.return_value = types.GenerateContentResponse(
|
||||
candidates=(
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
parts=(
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
data=b"raw-audio-bytes",
|
||||
mime_type="audio/L16;rate=24000",
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
config_entry.runtime_data = client_mock
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
"service_data",
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {},
|
||||
},
|
||||
),
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
|
||||
},
|
||||
),
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
|
||||
},
|
||||
),
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
|
||||
},
|
||||
),
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {},
|
||||
},
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
|
||||
},
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
@pytest.mark.usefixtures("setup")
|
||||
async def test_tts_service_speak(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test tts service."""
|
||||
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._genai_client.models.generate_content.reset_mock()
|
||||
tts_entity._genai_client.aio.models.generate_content.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
"speak",
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
@ -221,10 +175,9 @@ async def test_tts_service_speak(
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "zephyr")
|
||||
model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)
|
||||
|
||||
tts_entity._genai_client.models.generate_content.assert_called_once_with(
|
||||
model=model_id,
|
||||
tts_entity._genai_client.aio.models.generate_content.assert_called_once_with(
|
||||
model=TEST_CHAT_MODEL,
|
||||
contents="There is a person at the front door.",
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
@ -233,109 +186,52 @@ async def test_tts_service_speak(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
|
||||
)
|
||||
),
|
||||
temperature=RECOMMENDED_TEMPERATURE,
|
||||
top_k=RECOMMENDED_TOP_K,
|
||||
top_p=RECOMMENDED_TOP_P,
|
||||
max_output_tokens=RECOMMENDED_MAX_TOKENS,
|
||||
safety_settings=[
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_LANGUAGE: "de-DE",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_LANGUAGE: "it-IT",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_tts_service_speak_lang_config(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call with languages in the config."""
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._genai_client.models.generate_content.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._genai_client.models.generate_content.assert_called_once_with(
|
||||
model=RECOMMENDED_TTS_MODEL,
|
||||
contents="There is a person at the front door.",
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
@pytest.mark.usefixtures("setup")
|
||||
async def test_tts_service_speak_error(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call with HTTP response 500."""
|
||||
service_data = {
|
||||
ATTR_ENTITY_ID: "tts.google_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
}
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._genai_client.models.generate_content.reset_mock()
|
||||
tts_entity._genai_client.models.generate_content.side_effect = API_ERROR_500
|
||||
tts_entity._genai_client.aio.models.generate_content.reset_mock()
|
||||
tts_entity._genai_client.aio.models.generate_content.side_effect = API_ERROR_500
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
"speak",
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
@ -346,70 +242,39 @@ async def test_tts_service_speak_error(
|
||||
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
tts_entity._genai_client.models.generate_content.assert_called_once_with(
|
||||
model=RECOMMENDED_TTS_MODEL,
|
||||
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE)
|
||||
|
||||
tts_entity._genai_client.aio.models.generate_content.assert_called_once_with(
|
||||
model=TEST_CHAT_MODEL,
|
||||
contents="There is a person at the front door.",
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data"),
|
||||
[
|
||||
(
|
||||
"mock_config_entry_setup",
|
||||
"speak",
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {},
|
||||
},
|
||||
),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_tts_service_speak_without_options(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call with HTTP response 200."""
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._genai_client.models.generate_content.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
tts_service,
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._genai_client.models.generate_content.assert_called_once_with(
|
||||
model=RECOMMENDED_TTS_MODEL,
|
||||
contents="There is a person at the front door.",
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="zephyr")
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
|
||||
)
|
||||
),
|
||||
temperature=RECOMMENDED_TEMPERATURE,
|
||||
top_k=RECOMMENDED_TOP_K,
|
||||
top_p=RECOMMENDED_TOP_P,
|
||||
max_output_tokens=RECOMMENDED_MAX_TOKENS,
|
||||
safety_settings=[
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|