Add support for Gemini's new TTS capabilities (#145872)

* Add support for Gemini TTS

* Add tests

* Use wave library and update a few comments
This commit is contained in:
Markus Lanthaler 2025-06-15 07:21:04 +02:00 committed by GitHub
parent ec02f6d010
commit c988d1ce36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 635 additions and 1 deletions

View File

@ -45,7 +45,10 @@ CONF_IMAGE_FILENAME = "image_filename"
CONF_FILENAMES = "filenames"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
PLATFORMS = (
Platform.CONVERSATION,
Platform.TTS,
)
type GoogleGenerativeAIConfigEntry = ConfigEntry[Client]

View File

@ -6,9 +6,11 @@ DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "models/gemini-2.0-flash"
RECOMMENDED_TTS_MODEL = "gemini-2.5-flash-preview-tts"
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_TOP_P = "top_p"

View File

@ -0,0 +1,216 @@
"""Text to speech support for Google Generative AI."""
from __future__ import annotations
from contextlib import suppress
import io
import logging
from typing import Any
import wave
from google.genai import types
from homeassistant.components.tts import (
ATTR_VOICE,
TextToSpeechEntity,
TtsAudioType,
Voice,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import ATTR_MODEL, DOMAIN, RECOMMENDED_TTS_MODEL
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up TTS entity."""
tts_entity = GoogleGenerativeAITextToSpeechEntity(config_entry)
async_add_entities([tts_entity])
class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
"""Google Generative AI text-to-speech entity."""
_attr_supported_options = [ATTR_VOICE, ATTR_MODEL]
# See https://ai.google.dev/gemini-api/docs/speech-generation#languages
_attr_supported_languages = [
"ar-EG",
"bn-BD",
"de-DE",
"en-IN",
"en-US",
"es-US",
"fr-FR",
"hi-IN",
"id-ID",
"it-IT",
"ja-JP",
"ko-KR",
"mr-IN",
"nl-NL",
"pl-PL",
"pt-BR",
"ro-RO",
"ru-RU",
"ta-IN",
"te-IN",
"th-TH",
"tr-TR",
"uk-UA",
"vi-VN",
]
_attr_default_language = "en-US"
# See https://ai.google.dev/gemini-api/docs/speech-generation#voices
_supported_voices = [
Voice(voice.split(" ", 1)[0].lower(), voice)
for voice in (
"Zephyr (Bright)",
"Puck (Upbeat)",
"Charon (Informative)",
"Kore (Firm)",
"Fenrir (Excitable)",
"Leda (Youthful)",
"Orus (Firm)",
"Aoede (Breezy)",
"Callirrhoe (Easy-going)",
"Autonoe (Bright)",
"Enceladus (Breathy)",
"Iapetus (Clear)",
"Umbriel (Easy-going)",
"Algieba (Smooth)",
"Despina (Smooth)",
"Erinome (Clear)",
"Algenib (Gravelly)",
"Rasalgethi (Informative)",
"Laomedeia (Upbeat)",
"Achernar (Soft)",
"Alnilam (Firm)",
"Schedar (Even)",
"Gacrux (Mature)",
"Pulcherrima (Forward)",
"Achird (Friendly)",
"Zubenelgenubi (Casual)",
"Vindemiatrix (Gentle)",
"Sadachbia (Lively)",
"Sadaltager (Knowledgeable)",
"Sulafat (Warm)",
)
]
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize Google Generative AI Conversation speech entity."""
self.entry = entry
self._attr_name = "Google Generative AI TTS"
self._attr_unique_id = f"{entry.entry_id}_tts"
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
name=entry.title,
manufacturer="Google",
model="Generative AI",
entry_type=dr.DeviceEntryType.SERVICE,
)
self._genai_client = entry.runtime_data
self._default_voice_id = self._supported_voices[0].voice_id
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return a list of supported voices for a language."""
return self._supported_voices
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine."""
try:
response = self._genai_client.models.generate_content(
model=options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL),
contents=message,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=options.get(
ATTR_VOICE, self._default_voice_id
)
)
)
),
),
)
data = response.candidates[0].content.parts[0].inline_data.data
mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
except Exception as exc:
_LOGGER.warning(
"Error during processing of TTS request %s", exc, exc_info=True
)
raise HomeAssistantError(exc) from exc
return "wav", self._convert_to_wav(data, mime_type)
def _convert_to_wav(self, audio_data: bytes, mime_type: str) -> bytes:
"""Generate a WAV file header for the given audio data and parameters.
Args:
audio_data: The raw audio data as a bytes object.
mime_type: Mime type of the audio data.
Returns:
A bytes object representing the WAV file header.
"""
parameters = self._parse_audio_mime_type(mime_type)
wav_buffer = io.BytesIO()
with wave.open(wav_buffer, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(parameters["bits_per_sample"] // 8)
wf.setframerate(parameters["rate"])
wf.writeframes(audio_data)
return wav_buffer.getvalue()
def _parse_audio_mime_type(self, mime_type: str) -> dict[str, int]:
"""Parse bits per sample and rate from an audio MIME type string.
Assumes bits per sample is encoded like "L16" and rate as "rate=xxxxx".
Args:
mime_type: The audio MIME type string (e.g., "audio/L16;rate=24000").
Returns:
A dictionary with "bits_per_sample" and "rate" keys. Values will be
integers if found, otherwise None.
"""
if not mime_type.startswith("audio/L"):
_LOGGER.warning("Received unexpected MIME type %s", mime_type)
raise HomeAssistantError(f"Unsupported audio MIME type: {mime_type}")
bits_per_sample = 16
rate = 24000
# Extract rate from parameters
parts = mime_type.split(";")
for param in parts: # Skip the main type part
param = param.strip()
if param.lower().startswith("rate="):
# Handle cases like "rate=" with no value or non-integer value and keep rate as default
with suppress(ValueError, IndexError):
rate_str = param.split("=", 1)[1]
rate = int(rate_str)
elif param.startswith("audio/L"):
# Keep bits_per_sample as default if conversion fails
with suppress(ValueError, IndexError):
bits_per_sample = int(param.split("L", 1)[1])
return {"bits_per_sample": bits_per_sample, "rate": rate}

View File

@ -0,0 +1,413 @@
"""Tests for the Google Generative AI Conversation TTS entity."""
from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from google.genai import types
import pytest
from homeassistant.components import tts
from homeassistant.components.google_generative_ai_conversation.tts import (
ATTR_MODEL,
DOMAIN,
RECOMMENDED_TTS_MODEL,
)
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_PLATFORM
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core_config import async_process_ha_core_config
from homeassistant.setup import async_setup_component
from . import API_ERROR_500
from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
"""Mock writing tags."""
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir."""
@pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
"""Mock media player calls."""
return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@pytest.fixture(autouse=True)
async def setup_internal_url(hass: HomeAssistant) -> None:
"""Set up internal url."""
await async_process_ha_core_config(
hass, {"internal_url": "http://example.local:8123"}
)
@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
"""Mock genai_client."""
client = Mock()
client.aio.models.get = AsyncMock()
client.models.generate_content.return_value = types.GenerateContentResponse(
candidates=(
types.Candidate(
content=types.Content(
parts=(
types.Part(
inline_data=types.Blob(
data=b"raw-audio-bytes",
mime_type="audio/L16;rate=24000",
)
),
)
)
),
)
)
with patch(
"homeassistant.components.google_generative_ai_conversation.Client",
return_value=client,
) as mock_client:
yield mock_client
@pytest.fixture(name="setup")
async def setup_fixture(
hass: HomeAssistant,
config: dict[str, Any],
request: pytest.FixtureRequest,
mock_genai_client: AsyncMock,
) -> None:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, config)
if request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, config)
else:
raise RuntimeError("Invalid setup fixture")
await hass.async_block_till_done()
@pytest.fixture(name="config")
def config_fixture() -> dict[str, Any]:
"""Return config."""
return {
CONF_API_KEY: "bla",
}
async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock setup."""
assert await async_setup_component(
hass, tts.DOMAIN, {tts.DOMAIN: {CONF_PLATFORM: DOMAIN} | config}
)
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock config entry setup."""
default_config = {tts.CONF_LANG: "en-US"}
config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)
client_mock = Mock()
client_mock.models.get = None
client_mock.models.generate_content.return_value = types.GenerateContentResponse(
candidates=(
types.Candidate(
content=types.Content(
parts=(
types.Part(
inline_data=types.Blob(
data=b"raw-audio-bytes",
mime_type="audio/L16;rate=24000",
)
),
)
)
),
)
)
config_entry.runtime_data = client_mock
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test tts service."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "zephyr")
model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=model_id,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
)
),
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_LANGUAGE: "de-DE",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_LANGUAGE: "it-IT",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak_lang_config(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with languages in the config."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
)
),
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak_error(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with HTTP response 500."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
tts_entity._genai_client.models.generate_content.side_effect = API_ERROR_500
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.INTERNAL_SERVER_ERROR
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
)
),
),
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data"),
[
(
"mock_config_entry_setup",
"speak",
{
ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is a person at the front door.",
tts.ATTR_OPTIONS: {},
},
),
],
indirect=["setup"],
)
async def test_tts_service_speak_without_options(
setup: AsyncMock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
tts_service: str,
service_data: dict[str, Any],
) -> None:
"""Test service call with HTTP response 200."""
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
tts_entity._genai_client.models.generate_content.reset_mock()
await hass.services.async_call(
tts.DOMAIN,
tts_service,
service_data,
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
tts_entity._genai_client.models.generate_content.assert_called_once_with(
model=RECOMMENDED_TTS_MODEL,
contents="There is a person at the front door.",
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="zephyr")
)
),
),
)