Mirror of https://github.com/home-assistant/core.git (synced 2025-07-16 09:47:13 +00:00)
Add support for Gemini's new TTS capabilities (#145872)

* Add support for Gemini TTS
* Add tests
* Use wave library and update a few comments

Parent: ec02f6d010
Commit: c988d1ce36
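The new entity is driven through the standard tts.speak action. Below is a minimal sketch of such a call, assuming a running Home Assistant instance (hass) inside an async context; the media player target is a placeholder, and the voice/model option keys mirror the ones exercised by the tests later in this diff.

# Minimal sketch: invoking the new Gemini TTS entity via the tts.speak service.
# Assumes a running Home Assistant instance ("hass") in an async context.
# "media_player.living_room" is a placeholder target.
await hass.services.async_call(
    "tts",
    "speak",
    {
        "entity_id": "tts.google_generative_ai_tts",
        "media_player_entity_id": "media_player.living_room",
        "message": "There is a person at the front door.",
        "options": {"voice": "zephyr", "model": "gemini-2.5-flash-preview-tts"},
    },
    blocking=True,
)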
homeassistant/components/google_generative_ai_conversation/__init__.py
@@ -45,7 +45,10 @@ CONF_IMAGE_FILENAME = "image_filename"
 CONF_FILENAMES = "filenames"

 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
-PLATFORMS = (Platform.CONVERSATION,)
+PLATFORMS = (
+    Platform.CONVERSATION,
+    Platform.TTS,
+)

 type GoogleGenerativeAIConfigEntry = ConfigEntry[Client]

homeassistant/components/google_generative_ai_conversation/const.py
@@ -6,9 +6,11 @@ DOMAIN = "google_generative_ai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"

+ATTR_MODEL = "model"
 CONF_RECOMMENDED = "recommended"
 CONF_CHAT_MODEL = "chat_model"
 RECOMMENDED_CHAT_MODEL = "models/gemini-2.0-flash"
+RECOMMENDED_TTS_MODEL = "gemini-2.5-flash-preview-tts"
 CONF_TEMPERATURE = "temperature"
 RECOMMENDED_TEMPERATURE = 1.0
 CONF_TOP_P = "top_p"
homeassistant/components/google_generative_ai_conversation/tts.py (new file, 216 lines)
@@ -0,0 +1,216 @@
"""Text to speech support for Google Generative AI."""

from __future__ import annotations

from contextlib import suppress
import io
import logging
from typing import Any
import wave

from google.genai import types

from homeassistant.components.tts import (
    ATTR_VOICE,
    TextToSpeechEntity,
    TtsAudioType,
    Voice,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from .const import ATTR_MODEL, DOMAIN, RECOMMENDED_TTS_MODEL

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up TTS entity."""
    tts_entity = GoogleGenerativeAITextToSpeechEntity(config_entry)
    async_add_entities([tts_entity])


class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
    """Google Generative AI text-to-speech entity."""

    _attr_supported_options = [ATTR_VOICE, ATTR_MODEL]
    # See https://ai.google.dev/gemini-api/docs/speech-generation#languages
    _attr_supported_languages = [
        "ar-EG",
        "bn-BD",
        "de-DE",
        "en-IN",
        "en-US",
        "es-US",
        "fr-FR",
        "hi-IN",
        "id-ID",
        "it-IT",
        "ja-JP",
        "ko-KR",
        "mr-IN",
        "nl-NL",
        "pl-PL",
        "pt-BR",
        "ro-RO",
        "ru-RU",
        "ta-IN",
        "te-IN",
        "th-TH",
        "tr-TR",
        "uk-UA",
        "vi-VN",
    ]
    _attr_default_language = "en-US"
    # See https://ai.google.dev/gemini-api/docs/speech-generation#voices
    _supported_voices = [
        Voice(voice.split(" ", 1)[0].lower(), voice)
        for voice in (
            "Zephyr (Bright)",
            "Puck (Upbeat)",
            "Charon (Informative)",
            "Kore (Firm)",
            "Fenrir (Excitable)",
            "Leda (Youthful)",
            "Orus (Firm)",
            "Aoede (Breezy)",
            "Callirrhoe (Easy-going)",
            "Autonoe (Bright)",
            "Enceladus (Breathy)",
            "Iapetus (Clear)",
            "Umbriel (Easy-going)",
            "Algieba (Smooth)",
            "Despina (Smooth)",
            "Erinome (Clear)",
            "Algenib (Gravelly)",
            "Rasalgethi (Informative)",
            "Laomedeia (Upbeat)",
            "Achernar (Soft)",
            "Alnilam (Firm)",
            "Schedar (Even)",
            "Gacrux (Mature)",
            "Pulcherrima (Forward)",
            "Achird (Friendly)",
            "Zubenelgenubi (Casual)",
            "Vindemiatrix (Gentle)",
            "Sadachbia (Lively)",
            "Sadaltager (Knowledgeable)",
            "Sulafat (Warm)",
        )
    ]

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize Google Generative AI Conversation speech entity."""
        self.entry = entry
        self._attr_name = "Google Generative AI TTS"
        self._attr_unique_id = f"{entry.entry_id}_tts"
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            name=entry.title,
            manufacturer="Google",
            model="Generative AI",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        self._genai_client = entry.runtime_data
        self._default_voice_id = self._supported_voices[0].voice_id

    @callback
    def async_get_supported_voices(self, language: str) -> list[Voice] | None:
        """Return a list of supported voices for a language."""
        return self._supported_voices

    async def async_get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> TtsAudioType:
        """Load tts audio file from the engine."""
        try:
            response = self._genai_client.models.generate_content(
                model=options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL),
                contents=message,
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(
                                voice_name=options.get(
                                    ATTR_VOICE, self._default_voice_id
                                )
                            )
                        )
                    ),
                ),
            )

            data = response.candidates[0].content.parts[0].inline_data.data
            mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
        except Exception as exc:
            _LOGGER.warning(
                "Error during processing of TTS request %s", exc, exc_info=True
            )
            raise HomeAssistantError(exc) from exc
        return "wav", self._convert_to_wav(data, mime_type)

    def _convert_to_wav(self, audio_data: bytes, mime_type: str) -> bytes:
        """Generate a WAV file from raw audio data and its MIME parameters.

        Args:
            audio_data: The raw audio data as a bytes object.
            mime_type: Mime type of the audio data.

        Returns:
            A bytes object containing the complete WAV file (header plus audio data).

        """
        parameters = self._parse_audio_mime_type(mime_type)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(parameters["bits_per_sample"] // 8)
            wf.setframerate(parameters["rate"])
            wf.writeframes(audio_data)

        return wav_buffer.getvalue()

    def _parse_audio_mime_type(self, mime_type: str) -> dict[str, int]:
        """Parse bits per sample and rate from an audio MIME type string.

        Assumes bits per sample is encoded like "L16" and rate as "rate=xxxxx".

        Args:
            mime_type: The audio MIME type string (e.g., "audio/L16;rate=24000").

        Returns:
            A dictionary with "bits_per_sample" and "rate" keys. Defaults of 16
            and 24000 are used for values that cannot be parsed from the string.

        """
        if not mime_type.startswith("audio/L"):
            _LOGGER.warning("Received unexpected MIME type %s", mime_type)
            raise HomeAssistantError(f"Unsupported audio MIME type: {mime_type}")

        bits_per_sample = 16
        rate = 24000

        # Extract bits per sample and rate from the MIME type parameters
        parts = mime_type.split(";")
        for param in parts:
            param = param.strip()
            if param.lower().startswith("rate="):
                # Handle cases like "rate=" with no value or a non-integer value
                # and keep the default rate
                with suppress(ValueError, IndexError):
                    rate_str = param.split("=", 1)[1]
                    rate = int(rate_str)
            elif param.startswith("audio/L"):
                # Keep bits_per_sample as default if conversion fails
                with suppress(ValueError, IndexError):
                    bits_per_sample = int(param.split("L", 1)[1])

        return {"bits_per_sample": bits_per_sample, "rate": rate}
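As an aside, the PCM-to-WAV wrapping performed by _convert_to_wav above can be reproduced in isolation. The following is a standalone sketch using dummy silence in place of real Gemini output, with the 16-bit, 24 kHz mono defaults the entity parses from a MIME type such as "audio/L16;rate=24000".

# Standalone sketch of the same wave-library approach used by _convert_to_wav:
# wrap raw little-endian 16-bit mono PCM in a WAV container.
import io
import wave

pcm = b"\x00\x00" * 24000  # one second of silence; dummy stand-in for Gemini PCM
bits_per_sample = 16       # from the "L16" part of "audio/L16;rate=24000"
rate = 24000               # from the "rate=24000" parameter

buffer = io.BytesIO()
with wave.open(buffer, "wb") as wf:
    wf.setnchannels(1)                     # Gemini TTS output is mono
    wf.setsampwidth(bits_per_sample // 8)  # bytes per sample
    wf.setframerate(rate)
    wf.writeframes(pcm)

wav_bytes = buffer.getvalue()              # complete WAV file, ready to play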
tests/components/google_generative_ai_conversation/test_tts.py (new file, 413 lines)
@@ -0,0 +1,413 @@
"""Tests for the Google Generative AI Conversation TTS entity."""

from __future__ import annotations

from collections.abc import Generator
from http import HTTPStatus
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

from google.genai import types
import pytest

from homeassistant.components import tts
from homeassistant.components.google_generative_ai_conversation.tts import (
    ATTR_MODEL,
    DOMAIN,
    RECOMMENDED_TTS_MODEL,
)
from homeassistant.components.media_player import (
    ATTR_MEDIA_CONTENT_ID,
    DOMAIN as DOMAIN_MP,
    SERVICE_PLAY_MEDIA,
)
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_PLATFORM
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core_config import async_process_ha_core_config
from homeassistant.setup import async_setup_component

from . import API_ERROR_500

from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator


@pytest.fixture(autouse=True)
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
    """Mock writing tags."""


@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
    """Mock the TTS cache dir with empty dir."""


@pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
    """Mock media player calls."""
    return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)


@pytest.fixture(autouse=True)
async def setup_internal_url(hass: HomeAssistant) -> None:
    """Set up internal url."""
    await async_process_ha_core_config(
        hass, {"internal_url": "http://example.local:8123"}
    )


@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
    """Mock genai_client."""
    client = Mock()
    client.aio.models.get = AsyncMock()
    client.models.generate_content.return_value = types.GenerateContentResponse(
        candidates=(
            types.Candidate(
                content=types.Content(
                    parts=(
                        types.Part(
                            inline_data=types.Blob(
                                data=b"raw-audio-bytes",
                                mime_type="audio/L16;rate=24000",
                            )
                        ),
                    )
                )
            ),
        )
    )
    with patch(
        "homeassistant.components.google_generative_ai_conversation.Client",
        return_value=client,
    ) as mock_client:
        yield mock_client


@pytest.fixture(name="setup")
async def setup_fixture(
    hass: HomeAssistant,
    config: dict[str, Any],
    request: pytest.FixtureRequest,
    mock_genai_client: AsyncMock,
) -> None:
    """Set up the test environment."""
    if request.param == "mock_setup":
        await mock_setup(hass, config)
    elif request.param == "mock_config_entry_setup":
        await mock_config_entry_setup(hass, config)
    else:
        raise RuntimeError("Invalid setup fixture")

    await hass.async_block_till_done()


@pytest.fixture(name="config")
def config_fixture() -> dict[str, Any]:
    """Return config."""
    return {
        CONF_API_KEY: "bla",
    }


async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
    """Mock setup."""
    assert await async_setup_component(
        hass, tts.DOMAIN, {tts.DOMAIN: {CONF_PLATFORM: DOMAIN} | config}
    )


async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
    """Mock config entry setup."""
    default_config = {tts.CONF_LANG: "en-US"}
    config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)

    client_mock = Mock()
    client_mock.models.get = None
    client_mock.models.generate_content.return_value = types.GenerateContentResponse(
        candidates=(
            types.Candidate(
                content=types.Content(
                    parts=(
                        types.Part(
                            inline_data=types.Blob(
                                data=b"raw-audio-bytes",
                                mime_type="audio/L16;rate=24000",
                            )
                        ),
                    )
                )
            ),
        )
    )
    config_entry.runtime_data = client_mock
    config_entry.add_to_hass(hass)

    assert await hass.config_entries.async_setup(config_entry.entry_id)


@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Test tts service."""
    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    tts_entity._genai_client.models.generate_content.reset_mock()

    await hass.services.async_call(
        tts.DOMAIN,
        tts_service,
        service_data,
        blocking=True,
    )

    assert len(calls) == 1
    assert (
        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
        == HTTPStatus.OK
    )
    voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "zephyr")
    model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)

    tts_entity._genai_client.models.generate_content.assert_called_once_with(
        model=model_id,
        contents="There is a person at the front door.",
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
                )
            ),
        ),
    )


@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_LANGUAGE: "de-DE",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_LANGUAGE: "it-IT",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_lang_config(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Test service call with languages in the config."""
    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    tts_entity._genai_client.models.generate_content.reset_mock()

    await hass.services.async_call(
        tts.DOMAIN,
        tts_service,
        service_data,
        blocking=True,
    )

    assert len(calls) == 1
    assert (
        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
        == HTTPStatus.OK
    )

    tts_entity._genai_client.models.generate_content.assert_called_once_with(
        model=RECOMMENDED_TTS_MODEL,
        contents="There is a person at the front door.",
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
                )
            ),
        ),
    )


@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_error(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Test service call with HTTP response 500."""
    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    tts_entity._genai_client.models.generate_content.reset_mock()
    tts_entity._genai_client.models.generate_content.side_effect = API_ERROR_500

    await hass.services.async_call(
        tts.DOMAIN,
        tts_service,
        service_data,
        blocking=True,
    )

    assert len(calls) == 1
    assert (
        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
        == HTTPStatus.INTERNAL_SERVER_ERROR
    )

    tts_entity._genai_client.models.generate_content.assert_called_once_with(
        model=RECOMMENDED_TTS_MODEL,
        contents="There is a person at the front door.",
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
                )
            ),
        ),
    )


@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_without_options(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Test service call with HTTP response 200."""
    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    tts_entity._genai_client.models.generate_content.reset_mock()

    await hass.services.async_call(
        tts.DOMAIN,
        tts_service,
        service_data,
        blocking=True,
    )

    assert len(calls) == 1
    assert (
        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
        == HTTPStatus.OK
    )

    tts_entity._genai_client.models.generate_content.assert_called_once_with(
        model=RECOMMENDED_TTS_MODEL,
        contents="There is a person at the front door.",
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="zephyr")
                )
            ),
        ),
    )
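With a Home Assistant core development environment set up, these tests should be runnable by pointing pytest at tests/components/google_generative_ai_conversation/test_tts.py.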