Add Google AI STT (#147563)

tronikos 2025-07-16 05:11:29 -07:00 committed by GitHub
parent 26a9af7371
commit 02a11638b3
11 changed files with 897 additions and 74 deletions

View File

@ -36,12 +36,14 @@ from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_PROMPT,
DEFAULT_AI_TASK_NAME,
DEFAULT_STT_NAME,
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS,
)
@ -55,6 +57,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (
Platform.AI_TASK,
Platform.CONVERSATION,
Platform.STT,
Platform.TTS,
)
@ -301,7 +304,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
if not use_existing:
await hass.config_entries.async_remove(entry.entry_id)
else:
_add_ai_task_subentry(hass, entry)
_add_ai_task_and_stt_subentries(hass, entry)
hass.config_entries.async_update_entry(
entry,
title=DEFAULT_TITLE,
@ -350,8 +353,7 @@ async def async_migrate_entry(
hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
# Add AI Task subentry with default options
_add_ai_task_subentry(hass, entry)
_add_ai_task_and_stt_subentries(hass, entry)
hass.config_entries.async_update_entry(entry, minor_version=3)
if entry.version == 2 and entry.minor_version == 3:
@ -393,10 +395,10 @@ async def async_migrate_entry(
return True
def _add_ai_task_subentry(
def _add_ai_task_and_stt_subentries(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> None:
"""Add AI Task subentry to the config entry."""
"""Add AI Task and STT subentries to the config entry."""
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
@ -406,3 +408,12 @@ def _add_ai_task_subentry(
unique_id=None,
),
)
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_STT_OPTIONS),
subentry_type="stt",
title=DEFAULT_STT_NAME,
unique_id=None,
),
)
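
After migration, the presence of the new STT subentry can be sanity-checked the same way the tests further down in this diff do. The snippet below is illustrative only; it assumes `entry` is the migrated config entry and that `DEFAULT_STT_NAME` and `RECOMMENDED_STT_OPTIONS` are imported from the integration's `.const` module shown later:

stt_subentries = [
    subentry
    for subentry in entry.subentries.values()
    if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].title == DEFAULT_STT_NAME
assert dict(stt_subentries[0].data) == RECOMMENDED_STT_OPTIONS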

View File

@ -49,6 +49,8 @@ from .const import (
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_STT_PROMPT,
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
@ -57,6 +59,8 @@ from .const import (
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_STT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
@ -144,6 +148,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
{
"subentry_type": "stt",
"data": RECOMMENDED_STT_OPTIONS,
"title": DEFAULT_STT_NAME,
"unique_id": None,
},
],
)
return self.async_show_form(
@ -191,6 +201,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Return subentries supported by this integration."""
return {
"conversation": LLMSubentryFlowHandler,
"stt": LLMSubentryFlowHandler,
"tts": LLMSubentryFlowHandler,
"ai_task_data": LLMSubentryFlowHandler,
}
@ -228,6 +239,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
options = RECOMMENDED_TTS_OPTIONS.copy()
elif self._subentry_type == "ai_task_data":
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
elif self._subentry_type == "stt":
options = RECOMMENDED_STT_OPTIONS.copy()
else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
else:
@ -304,6 +317,8 @@ async def google_generative_ai_config_option_schema(
default_name = DEFAULT_TTS_NAME
elif subentry_type == "ai_task_data":
default_name = DEFAULT_AI_TASK_NAME
elif subentry_type == "stt":
default_name = DEFAULT_STT_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = {
@ -331,6 +346,17 @@ async def google_generative_ai_config_option_schema(
),
}
)
elif subentry_type == "stt":
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(CONF_PROMPT, DEFAULT_STT_PROMPT)
},
): TemplateSelector(),
}
)
schema.update(
{
@ -388,6 +414,8 @@ async def google_generative_ai_config_option_schema(
if subentry_type == "tts":
default_model = RECOMMENDED_TTS_MODEL
elif subentry_type == "stt":
default_model = RECOMMENDED_STT_MODEL
else:
default_model = RECOMMENDED_CHAT_MODEL

View File

@ -5,18 +5,23 @@ import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
LOGGER = logging.getLogger(__package__)
DOMAIN = "google_generative_ai_conversation"
DEFAULT_TITLE = "Google Generative AI"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_STT_NAME = "Google AI STT"
DEFAULT_TTS_NAME = "Google AI TTS"
DEFAULT_AI_TASK_NAME = "Google AI Task"
CONF_PROMPT = "prompt"
DEFAULT_STT_PROMPT = "Transcribe the attached audio"
CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
RECOMMENDED_STT_MODEL = RECOMMENDED_CHAT_MODEL
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
@ -43,6 +48,11 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
}
RECOMMENDED_STT_OPTIONS = {
CONF_PROMPT: DEFAULT_STT_PROMPT,
CONF_RECOMMENDED: True,
}
RECOMMENDED_TTS_OPTIONS = {
CONF_RECOMMENDED: True,
}
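
Spelled out with the constant values defined above, the new STT defaults that the migration and config flow fall back to are equivalent to the following. This is a restatement for reference, not additional code in the diff:

# Literal form of the defaults above:
RECOMMENDED_STT_MODEL = "models/gemini-2.5-flash"  # same value as RECOMMENDED_CHAT_MODEL
RECOMMENDED_STT_OPTIONS = {
    "prompt": "Transcribe the attached audio",  # CONF_PROMPT: DEFAULT_STT_PROMPT
    "recommended": True,                        # CONF_RECOMMENDED: True
}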

View File

@ -61,6 +61,38 @@
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
}
},
"stt": {
"initiate_flow": {
"user": "Add Speech-to-Text service",
"reconfigure": "Reconfigure Speech-to-Text service"
},
"entry_type": "Speech-to-Text",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"prompt": "Instructions",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
},
"data_description": {
"prompt": "Instruct how the LLM should transcribe the audio."
}
}
},
"abort": {
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
},
"tts": {
"initiate_flow": {
"user": "Add Text-to-Speech service",

View File

@ -0,0 +1,254 @@
"""Speech to text support for Google Generative AI."""
from __future__ import annotations
from collections.abc import AsyncIterable
from google.genai.errors import APIError, ClientError
from google.genai.types import Part
from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_STT_PROMPT,
LOGGER,
RECOMMENDED_STT_MODEL,
)
from .entity import GoogleGenerativeAILLMBaseEntity
from .helpers import convert_to_wav
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up STT entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "stt":
continue
async_add_entities(
[GoogleGenerativeAISttEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class GoogleGenerativeAISttEntity(
stt.SpeechToTextEntity, GoogleGenerativeAILLMBaseEntity
):
"""Google Generative AI speech-to-text entity."""
def __init__(self, config_entry: ConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the STT entity."""
super().__init__(config_entry, subentry, RECOMMENDED_STT_MODEL)
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return [
"af-ZA",
"sq-AL",
"am-ET",
"ar-DZ",
"ar-BH",
"ar-EG",
"ar-IQ",
"ar-IL",
"ar-JO",
"ar-KW",
"ar-LB",
"ar-MA",
"ar-OM",
"ar-QA",
"ar-SA",
"ar-PS",
"ar-TN",
"ar-AE",
"ar-YE",
"hy-AM",
"az-AZ",
"eu-ES",
"bn-BD",
"bn-IN",
"bs-BA",
"bg-BG",
"my-MM",
"ca-ES",
"zh-CN",
"zh-TW",
"hr-HR",
"cs-CZ",
"da-DK",
"nl-BE",
"nl-NL",
"en-AU",
"en-CA",
"en-GH",
"en-HK",
"en-IN",
"en-IE",
"en-KE",
"en-NZ",
"en-NG",
"en-PK",
"en-PH",
"en-SG",
"en-ZA",
"en-TZ",
"en-GB",
"en-US",
"et-EE",
"fil-PH",
"fi-FI",
"fr-BE",
"fr-CA",
"fr-FR",
"fr-CH",
"gl-ES",
"ka-GE",
"de-AT",
"de-DE",
"de-CH",
"el-GR",
"gu-IN",
"iw-IL",
"hi-IN",
"hu-HU",
"is-IS",
"id-ID",
"it-IT",
"it-CH",
"ja-JP",
"jv-ID",
"kn-IN",
"kk-KZ",
"km-KH",
"ko-KR",
"lo-LA",
"lv-LV",
"lt-LT",
"mk-MK",
"ms-MY",
"ml-IN",
"mr-IN",
"mn-MN",
"ne-NP",
"no-NO",
"fa-IR",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"ru-RU",
"sr-RS",
"si-LK",
"sk-SK",
"sl-SI",
"es-AR",
"es-BO",
"es-CL",
"es-CO",
"es-CR",
"es-DO",
"es-EC",
"es-SV",
"es-GT",
"es-HN",
"es-MX",
"es-NI",
"es-PA",
"es-PY",
"es-PE",
"es-PR",
"es-ES",
"es-US",
"es-UY",
"es-VE",
"su-ID",
"sw-KE",
"sw-TZ",
"sv-SE",
"ta-IN",
"ta-MY",
"ta-SG",
"ta-LK",
"te-IN",
"th-TH",
"tr-TR",
"uk-UA",
"ur-IN",
"ur-PK",
"uz-UZ",
"vi-VN",
"zu-ZA",
]
@property
def supported_formats(self) -> list[stt.AudioFormats]:
"""Return a list of supported formats."""
# https://ai.google.dev/gemini-api/docs/audio#supported-formats
return [stt.AudioFormats.WAV, stt.AudioFormats.OGG]
@property
def supported_codecs(self) -> list[stt.AudioCodecs]:
"""Return a list of supported codecs."""
return [stt.AudioCodecs.PCM, stt.AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
"""Return a list of supported bit rates."""
return [stt.AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
"""Return a list of supported sample rates."""
return [stt.AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[stt.AudioChannels]:
"""Return a list of supported channels."""
# Per https://ai.google.dev/gemini-api/docs/audio
# If the audio source contains multiple channels, Gemini combines those channels into a single channel.
return [stt.AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
) -> stt.SpeechResult:
"""Process an audio stream to STT service."""
audio_data = b""
async for chunk in stream:
audio_data += chunk
if metadata.format == stt.AudioFormats.WAV:
audio_data = convert_to_wav(
audio_data,
f"audio/L{metadata.bit_rate.value};rate={metadata.sample_rate.value}",
)
try:
response = await self._genai_client.aio.models.generate_content(
model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_STT_MODEL),
contents=[
self.subentry.data.get(CONF_PROMPT, DEFAULT_STT_PROMPT),
Part.from_bytes(
data=audio_data,
mime_type=f"audio/{metadata.format.value}",
),
],
config=self.create_generate_content_config(),
)
except (APIError, ClientError, ValueError) as err:
LOGGER.error("Error during STT: %s", err)
else:
if response.text:
return stt.SpeechResult(
response.text,
stt.SpeechResultState.SUCCESS,
)
return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
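
The convert_to_wav helper is imported from the integration's existing helpers module and is not part of this diff. As a rough sketch of what such a helper needs to do for the "audio/L16;rate=16000" mime type built above, assuming 16-bit mono PCM (the only input this entity advertises via its supported_* properties), the standard wave module is enough; the real helper may differ, but the point is that Gemini receives a self-describing WAV payload rather than headerless PCM:

import io
import wave


def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Illustrative sketch only: wrap raw PCM bytes in a WAV container.

    mime_type is expected to look like "audio/L16;rate=16000".
    """
    bits_per_sample = 16
    sample_rate = 16000
    # Pull the sample width and rate out of the mime type parameters.
    for param in mime_type.split(";"):
        param = param.strip()
        if param.startswith("rate="):
            sample_rate = int(param.split("=", 1)[1])
        elif param.startswith("audio/L"):
            bits_per_sample = int(param.removeprefix("audio/L"))

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)  # entity only advertises CHANNEL_MONO
        wav_file.setsampwidth(bits_per_sample // 8)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_data)
    return buffer.getvalue()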

View File

@ -9,6 +9,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TTS_NAME,
)
from homeassistant.config_entries import ConfigEntry
@ -39,6 +40,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"subentry_id": "ulid-conversation",
"unique_id": None,
},
{
"data": {},
"subentry_type": "stt",
"title": DEFAULT_STT_NAME,
"subentry_id": "ulid-stt",
"unique_id": None,
},
{
"data": {},
"subentry_type": "tts",

View File

@ -34,6 +34,14 @@
'title': 'Google AI Conversation',
'unique_id': None,
}),
'ulid-stt': dict({
'data': dict({
}),
'subentry_id': 'ulid-stt',
'subentry_type': 'stt',
'title': 'Google AI STT',
'unique_id': None,
}),
'ulid-tts': dict({
'data': dict({
}),

View File

@ -32,6 +32,37 @@
'sw_version': None,
'via_device_id': None,
}),
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'config_entries_subentries': <ANY>,
'configuration_url': None,
'connections': set({
}),
'disabled_by': None,
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
'hw_version': None,
'id': <ANY>,
'identifiers': set({
tuple(
'google_generative_ai_conversation',
'ulid-stt',
),
}),
'is_new': False,
'labels': set({
}),
'manufacturer': 'Google',
'model': 'gemini-2.5-flash',
'model_id': None,
'name': 'Google AI STT',
'name_by_user': None,
'primary_config_entry': <ANY>,
'serial_number': None,
'suggested_area': None,
'sw_version': None,
'via_device_id': None,
}),
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,

View File

@ -1,5 +1,6 @@
"""Test the Google Generative AI Conversation config flow."""
from typing import Any
from unittest.mock import Mock, patch
import pytest
@ -21,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
@ -28,8 +30,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_STT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
RECOMMENDED_TTS_MODEL,
RECOMMENDED_TTS_OPTIONS,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
)
@ -64,11 +69,17 @@ def get_models_pager():
)
model_15_pro.name = "models/gemini-1.5-pro-latest"
model_25_flash_tts = Mock(
supported_actions=["generateContent"],
)
model_25_flash_tts.name = "models/gemini-2.5-flash-preview-tts"
async def models_pager():
yield model_25_flash
yield model_20_flash
yield model_15_flash
yield model_15_pro
yield model_25_flash_tts
return models_pager()
@ -129,6 +140,12 @@ async def test_form(hass: HomeAssistant) -> None:
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
{
"subentry_type": "stt",
"data": RECOMMENDED_STT_OPTIONS,
"title": DEFAULT_STT_NAME,
"unique_id": None,
},
]
assert len(mock_setup_entry.mock_calls) == 1
@ -157,22 +174,35 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
assert result["reason"] == "already_configured"
async def test_creating_conversation_subentry(
@pytest.mark.parametrize(
("subentry_type", "options"),
[
("conversation", RECOMMENDED_CONVERSATION_OPTIONS),
("stt", RECOMMENDED_STT_OPTIONS),
("tts", RECOMMENDED_TTS_OPTIONS),
("ai_task_data", RECOMMENDED_AI_TASK_OPTIONS),
],
)
async def test_creating_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
subentry_type: str,
options: dict[str, Any],
) -> None:
"""Test creating a conversation subentry."""
"""Test creating a subentry."""
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
(mock_config_entry.entry_id, subentry_type),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
@ -182,31 +212,117 @@ async def test_creating_conversation_subentry(
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
result["data_schema"]({CONF_NAME: "Mock name", **options}),
)
await hass.async_block_till_done()
expected_options = options.copy()
if CONF_PROMPT in expected_options:
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name"
assert result2["data"] == expected_options
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
assert result2["data"] == processed_options
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == subentry_type
assert new_subentry.data == expected_options
assert new_subentry.title == "Mock name"
async def test_creating_tts_subentry(
@pytest.mark.parametrize(
("subentry_type", "recommended_model", "options"),
[
(
"conversation",
RECOMMENDED_CHAT_MODEL,
{
CONF_PROMPT: "You are Mario",
CONF_LLM_HASS_API: ["assist"],
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
},
),
(
"stt",
RECOMMENDED_STT_MODEL,
{
CONF_PROMPT: "Transcribe this",
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_STT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
(
"tts",
RECOMMENDED_TTS_MODEL,
{
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_TTS_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
(
"ai_task_data",
RECOMMENDED_CHAT_MODEL,
{
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
],
)
async def test_creating_subentry_custom_options(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
subentry_type: str,
recommended_model: str,
options: dict[str, Any],
) -> None:
"""Test creating a TTS subentry."""
"""Test creating a subentry with custom options."""
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "tts"),
(mock_config_entry.entry_id, subentry_type),
context={"source": config_entries.SOURCE_USER},
)
@ -214,75 +330,52 @@ async def test_creating_tts_subentry(
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
# Uncheck recommended to show custom options
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS},
result["data_schema"]({CONF_RECOMMENDED: False}),
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock TTS"
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
# Find the schema key for CONF_CHAT_MODEL and check its default
schema_dict = result2["data_schema"].schema
chat_model_key = next(key for key in schema_dict if key.schema == CONF_CHAT_MODEL)
assert chat_model_key.default() == recommended_model
models_in_selector = [
opt["value"] for opt in schema_dict[chat_model_key].config["options"]
]
assert recommended_model in models_in_selector
assert len(mock_config_entry.subentries) == 4
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "tts"
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
assert new_subentry.title == "Mock TTS"
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry."""
# Submit the form
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
result3 = await hass.config_entries.subentries.async_configure(
result2["flow_id"],
result2["data_schema"]({CONF_NAME: "Mock name", **options}),
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock AI Task"
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
expected_options = options.copy()
if CONF_PROMPT in expected_options:
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["title"] == "Mock name"
assert result3["data"] == expected_options
assert len(mock_config_entry.subentries) == 4
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task_data"
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert new_subentry.title == "Mock AI Task"
assert new_subentry.subentry_type == subentry_type
assert new_subentry.data == expected_options
assert new_subentry.title == "Mock name"
async def test_creating_conversation_subentry_not_loaded(

View File

@ -11,11 +11,13 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TTS_OPTIONS,
)
from homeassistant.config_entries import (
@ -489,7 +491,7 @@ async def test_migration_from_v1(
assert entry.minor_version == 4
assert not entry.options
assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4
assert len(entry.subentries) == 5
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
@ -516,6 +518,14 @@ async def test_migration_from_v1(
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0]
@ -721,7 +731,7 @@ async def test_migration_from_v1_disabled(
assert entry.minor_version == 4
assert not entry.options
assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4
assert len(entry.subentries) == 5
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
@ -748,6 +758,14 @@ async def test_migration_from_v1_disabled(
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
assert not device_registry.async_get_device(
identifiers={(DOMAIN, mock_config_entry.entry_id)}
@ -860,7 +878,7 @@ async def test_migration_from_v1_with_multiple_keys(
assert entry.minor_version == 4
assert not entry.options
assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 3
assert len(entry.subentries) == 4
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == options
@ -873,6 +891,10 @@ async def test_migration_from_v1_with_multiple_keys(
assert subentry.subentry_type == "ai_task_data"
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert subentry.title == DEFAULT_AI_TASK_NAME
subentry = list(entry.subentries.values())[3]
assert subentry.subentry_type == "stt"
assert subentry.data == RECOMMENDED_STT_OPTIONS
assert subentry.title == DEFAULT_STT_NAME
dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
@ -963,7 +985,7 @@ async def test_migration_from_v1_with_same_keys(
assert entry.minor_version == 4
assert not entry.options
assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4
assert len(entry.subentries) == 5
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
@ -990,6 +1012,14 @@ async def test_migration_from_v1_with_same_keys(
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0]
@ -1090,10 +1120,11 @@ async def test_migration_from_v2_1(
"""Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1 and add AI Task subentry:
2025.7.0b0-2025.7.0b1 and add AI Task and STT subentries:
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
- Add AI Task subentry (Added in version 2.3)
- Add STT subentry (Added in version 2.3)
"""
# Create a v2.1 config entry with 2 subentries, devices and entities
options = {
@ -1184,7 +1215,7 @@ async def test_migration_from_v2_1(
assert entry.minor_version == 4
assert not entry.options
assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4
assert len(entry.subentries) == 5
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
@ -1211,6 +1242,14 @@ async def test_migration_from_v2_1(
assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0]
@ -1320,8 +1359,8 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
assert entry.version == 2
assert entry.minor_version == 4
# Check we now have conversation, tts and ai_task_data subentries
assert len(entry.subentries) == 3
# Check we now have conversation, tts, stt, and ai_task_data subentries
assert len(entry.subentries) == 4
subentries = {
subentry.subentry_type: subentry for subentry in entry.subentries.values()
@ -1336,6 +1375,12 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
# Find and verify the stt subentry
stt_subentry = subentries["stt"]
assert stt_subentry is not None
assert stt_subentry.title == DEFAULT_STT_NAME
assert stt_subentry.data == RECOMMENDED_STT_OPTIONS
# Verify conversation subentry is still there and unchanged
conversation_subentry = subentries["conversation"]
assert conversation_subentry is not None

View File

@ -0,0 +1,303 @@
"""Tests for the Google Generative AI Conversation STT entity."""
from __future__ import annotations
from collections.abc import AsyncIterable, Generator
from unittest.mock import AsyncMock, Mock, patch
from google.genai import types
import pytest
from homeassistant.components import stt
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_STT_PROMPT,
DOMAIN,
RECOMMENDED_STT_MODEL,
)
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
from tests.common import MockConfigEntry
TEST_CHAT_MODEL = "models/gemini-2.5-flash"
TEST_PROMPT = "Please transcribe the audio."
async def _async_get_audio_stream(data: bytes) -> AsyncIterable[bytes]:
"""Yield the audio data."""
yield data
@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
"""Mock genai.Client."""
client = Mock()
client.aio.models.get = AsyncMock()
client.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "This is a test transcription."}],
"role": "model",
}
}
]
)
)
with patch(
"homeassistant.components.google_generative_ai_conversation.Client",
return_value=client,
) as mock_client:
yield mock_client.return_value
@pytest.fixture
async def setup_integration(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
) -> None:
"""Set up the test environment."""
config_entry = MockConfigEntry(
domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
)
config_entry.add_to_hass(hass)
sub_entry = ConfigSubentry(
data={
CONF_CHAT_MODEL: TEST_CHAT_MODEL,
CONF_PROMPT: TEST_PROMPT,
},
subentry_type="stt",
title="Google AI STT",
unique_id=None,
)
config_entry.runtime_data = mock_genai_client
hass.config_entries.async_add_subentry(config_entry, sub_entry)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
@pytest.mark.usefixtures("setup_integration")
async def test_stt_entity_properties(hass: HomeAssistant) -> None:
"""Test STT entity properties."""
entity: stt.SpeechToTextEntity = hass.data[stt.DOMAIN].get_entity(
"stt.google_ai_stt"
)
assert entity is not None
assert isinstance(entity.supported_languages, list)
assert stt.AudioFormats.WAV in entity.supported_formats
assert stt.AudioFormats.OGG in entity.supported_formats
assert stt.AudioCodecs.PCM in entity.supported_codecs
assert stt.AudioCodecs.OPUS in entity.supported_codecs
assert stt.AudioBitRates.BITRATE_16 in entity.supported_bit_rates
assert stt.AudioSampleRates.SAMPLERATE_16000 in entity.supported_sample_rates
assert stt.AudioChannels.CHANNEL_MONO in entity.supported_channels
@pytest.mark.parametrize(
("audio_format", "call_convert_to_wav"),
[
(stt.AudioFormats.WAV, True),
(stt.AudioFormats.OGG, False),
],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_success(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
audio_format: stt.AudioFormats,
call_convert_to_wav: bool,
) -> None:
"""Test STT processing audio stream successfully."""
entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
metadata = stt.SpeechMetadata(
language="en-US",
format=audio_format,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
audio_stream = _async_get_audio_stream(b"test_audio_bytes")
with patch(
"homeassistant.components.google_generative_ai_conversation.stt.convert_to_wav",
return_value=b"converted_wav_bytes",
) as mock_convert_to_wav:
result = await entity.async_process_audio_stream(metadata, audio_stream)
assert result.result == stt.SpeechResultState.SUCCESS
assert result.text == "This is a test transcription."
if call_convert_to_wav:
mock_convert_to_wav.assert_called_once_with(
b"test_audio_bytes", "audio/L16;rate=16000"
)
else:
mock_convert_to_wav.assert_not_called()
mock_genai_client.aio.models.generate_content.assert_called_once()
call_args = mock_genai_client.aio.models.generate_content.call_args
assert call_args.kwargs["model"] == TEST_CHAT_MODEL
contents = call_args.kwargs["contents"]
assert contents[0] == TEST_PROMPT
assert isinstance(contents[1], types.Part)
assert contents[1].inline_data.mime_type == f"audio/{audio_format.value}"
if call_convert_to_wav:
assert contents[1].inline_data.data == b"converted_wav_bytes"
else:
assert contents[1].inline_data.data == b"test_audio_bytes"
@pytest.mark.parametrize(
"side_effect",
[
API_ERROR_500,
CLIENT_ERROR_BAD_REQUEST,
ValueError("Test value error"),
],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_api_error(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
side_effect: Exception,
) -> None:
"""Test STT processing audio stream with API errors."""
entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
mock_genai_client.aio.models.generate_content.side_effect = side_effect
metadata = stt.SpeechMetadata(
language="en-US",
format=stt.AudioFormats.OGG,
codec=stt.AudioCodecs.OPUS,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
audio_stream = _async_get_audio_stream(b"test_audio_bytes")
result = await entity.async_process_audio_stream(metadata, audio_stream)
assert result.result == stt.SpeechResultState.ERROR
assert result.text is None
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_empty_response(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
) -> None:
"""Test STT processing with an empty response from the API."""
entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
mock_genai_client.aio.models.generate_content.return_value = (
types.GenerateContentResponse(candidates=[])
)
metadata = stt.SpeechMetadata(
language="en-US",
format=stt.AudioFormats.OGG,
codec=stt.AudioCodecs.OPUS,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
audio_stream = _async_get_audio_stream(b"test_audio_bytes")
result = await entity.async_process_audio_stream(metadata, audio_stream)
assert result.result == stt.SpeechResultState.ERROR
assert result.text is None
@pytest.mark.usefixtures("mock_genai_client")
async def test_stt_uses_default_prompt(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
) -> None:
"""Test that the default prompt is used if none is configured."""
config_entry = MockConfigEntry(
domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
)
config_entry.add_to_hass(hass)
config_entry.runtime_data = mock_genai_client
# Subentry with no prompt
sub_entry = ConfigSubentry(
data={CONF_CHAT_MODEL: TEST_CHAT_MODEL},
subentry_type="stt",
title="Google AI STT",
unique_id=None,
)
hass.config_entries.async_add_subentry(config_entry, sub_entry)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
metadata = stt.SpeechMetadata(
language="en-US",
format=stt.AudioFormats.OGG,
codec=stt.AudioCodecs.OPUS,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
audio_stream = _async_get_audio_stream(b"test_audio_bytes")
await entity.async_process_audio_stream(metadata, audio_stream)
call_args = mock_genai_client.aio.models.generate_content.call_args
contents = call_args.kwargs["contents"]
assert contents[0] == DEFAULT_STT_PROMPT
@pytest.mark.usefixtures("mock_genai_client")
async def test_stt_uses_default_model(
hass: HomeAssistant,
mock_genai_client: AsyncMock,
) -> None:
"""Test that the default model is used if none is configured."""
config_entry = MockConfigEntry(
domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
)
config_entry.add_to_hass(hass)
config_entry.runtime_data = mock_genai_client
# Subentry with no model
sub_entry = ConfigSubentry(
data={CONF_PROMPT: TEST_PROMPT},
subentry_type="stt",
title="Google AI STT",
unique_id=None,
)
hass.config_entries.async_add_subentry(config_entry, sub_entry)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
metadata = stt.SpeechMetadata(
language="en-US",
format=stt.AudioFormats.OGG,
codec=stt.AudioCodecs.OPUS,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
audio_stream = _async_get_audio_stream(b"test_audio_bytes")
await entity.async_process_audio_stream(metadata, audio_stream)
call_args = mock_genai_client.aio.models.generate_content.call_args
assert call_args.kwargs["model"] == RECOMMENDED_STT_MODEL