Add Google AI STT (#147563)

This commit is contained in:
tronikos 2025-07-16 05:11:29 -07:00 committed by GitHub
parent 26a9af7371
commit 02a11638b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 897 additions and 74 deletions

View File

@ -36,12 +36,14 @@ from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
CONF_PROMPT, CONF_PROMPT,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_STT_NAME,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TTS_OPTIONS, RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS, TIMEOUT_MILLIS,
) )
@ -55,6 +57,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = ( PLATFORMS = (
Platform.AI_TASK, Platform.AI_TASK,
Platform.CONVERSATION, Platform.CONVERSATION,
Platform.STT,
Platform.TTS, Platform.TTS,
) )
@ -301,7 +304,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
if not use_existing: if not use_existing:
await hass.config_entries.async_remove(entry.entry_id) await hass.config_entries.async_remove(entry.entry_id)
else: else:
_add_ai_task_subentry(hass, entry) _add_ai_task_and_stt_subentries(hass, entry)
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
entry, entry,
title=DEFAULT_TITLE, title=DEFAULT_TITLE,
@ -350,8 +353,7 @@ async def async_migrate_entry(
hass.config_entries.async_update_entry(entry, minor_version=2) hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2: if entry.version == 2 and entry.minor_version == 2:
# Add AI Task subentry with default options _add_ai_task_and_stt_subentries(hass, entry)
_add_ai_task_subentry(hass, entry)
hass.config_entries.async_update_entry(entry, minor_version=3) hass.config_entries.async_update_entry(entry, minor_version=3)
if entry.version == 2 and entry.minor_version == 3: if entry.version == 2 and entry.minor_version == 3:
@ -393,10 +395,10 @@ async def async_migrate_entry(
return True return True
def _add_ai_task_subentry( def _add_ai_task_and_stt_subentries(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> None: ) -> None:
"""Add AI Task subentry to the config entry.""" """Add AI Task and STT subentries to the config entry."""
hass.config_entries.async_add_subentry( hass.config_entries.async_add_subentry(
entry, entry,
ConfigSubentry( ConfigSubentry(
@ -406,3 +408,12 @@ def _add_ai_task_subentry(
unique_id=None, unique_id=None,
), ),
) )
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_STT_OPTIONS),
subentry_type="stt",
title=DEFAULT_STT_NAME,
unique_id=None,
),
)

View File

@ -49,6 +49,8 @@ from .const import (
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_STT_PROMPT,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
@ -57,6 +59,8 @@ from .const import (
RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_STT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TEMPERATURE, RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K, RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
@ -144,6 +148,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"title": DEFAULT_AI_TASK_NAME, "title": DEFAULT_AI_TASK_NAME,
"unique_id": None, "unique_id": None,
}, },
{
"subentry_type": "stt",
"data": RECOMMENDED_STT_OPTIONS,
"title": DEFAULT_STT_NAME,
"unique_id": None,
},
], ],
) )
return self.async_show_form( return self.async_show_form(
@ -191,6 +201,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Return subentries supported by this integration.""" """Return subentries supported by this integration."""
return { return {
"conversation": LLMSubentryFlowHandler, "conversation": LLMSubentryFlowHandler,
"stt": LLMSubentryFlowHandler,
"tts": LLMSubentryFlowHandler, "tts": LLMSubentryFlowHandler,
"ai_task_data": LLMSubentryFlowHandler, "ai_task_data": LLMSubentryFlowHandler,
} }
@ -228,6 +239,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
options = RECOMMENDED_TTS_OPTIONS.copy() options = RECOMMENDED_TTS_OPTIONS.copy()
elif self._subentry_type == "ai_task_data": elif self._subentry_type == "ai_task_data":
options = RECOMMENDED_AI_TASK_OPTIONS.copy() options = RECOMMENDED_AI_TASK_OPTIONS.copy()
elif self._subentry_type == "stt":
options = RECOMMENDED_STT_OPTIONS.copy()
else: else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy() options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
else: else:
@ -304,6 +317,8 @@ async def google_generative_ai_config_option_schema(
default_name = DEFAULT_TTS_NAME default_name = DEFAULT_TTS_NAME
elif subentry_type == "ai_task_data": elif subentry_type == "ai_task_data":
default_name = DEFAULT_AI_TASK_NAME default_name = DEFAULT_AI_TASK_NAME
elif subentry_type == "stt":
default_name = DEFAULT_STT_NAME
else: else:
default_name = DEFAULT_CONVERSATION_NAME default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = { schema: dict[vol.Required | vol.Optional, Any] = {
@ -331,6 +346,17 @@ async def google_generative_ai_config_option_schema(
), ),
} }
) )
elif subentry_type == "stt":
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(CONF_PROMPT, DEFAULT_STT_PROMPT)
},
): TemplateSelector(),
}
)
schema.update( schema.update(
{ {
@ -388,6 +414,8 @@ async def google_generative_ai_config_option_schema(
if subentry_type == "tts": if subentry_type == "tts":
default_model = RECOMMENDED_TTS_MODEL default_model = RECOMMENDED_TTS_MODEL
elif subentry_type == "stt":
default_model = RECOMMENDED_STT_MODEL
else: else:
default_model = RECOMMENDED_CHAT_MODEL default_model = RECOMMENDED_CHAT_MODEL

View File

@ -5,18 +5,23 @@ import logging
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm from homeassistant.helpers import llm
LOGGER = logging.getLogger(__package__)
DOMAIN = "google_generative_ai_conversation" DOMAIN = "google_generative_ai_conversation"
DEFAULT_TITLE = "Google Generative AI" DEFAULT_TITLE = "Google Generative AI"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation" DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_STT_NAME = "Google AI STT"
DEFAULT_TTS_NAME = "Google AI TTS" DEFAULT_TTS_NAME = "Google AI TTS"
DEFAULT_AI_TASK_NAME = "Google AI Task" DEFAULT_AI_TASK_NAME = "Google AI Task"
CONF_PROMPT = "prompt"
DEFAULT_STT_PROMPT = "Transcribe the attached audio"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash" RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
RECOMMENDED_STT_MODEL = RECOMMENDED_CHAT_MODEL
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts" RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0 RECOMMENDED_TEMPERATURE = 1.0
@ -43,6 +48,11 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
} }
RECOMMENDED_STT_OPTIONS = {
CONF_PROMPT: DEFAULT_STT_PROMPT,
CONF_RECOMMENDED: True,
}
RECOMMENDED_TTS_OPTIONS = { RECOMMENDED_TTS_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
} }

View File

@ -61,6 +61,38 @@
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting." "invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
} }
}, },
"stt": {
"initiate_flow": {
"user": "Add Speech-to-Text service",
"reconfigure": "Reconfigure Speech-to-Text service"
},
"entry_type": "Speech-to-Text",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"prompt": "Instructions",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
},
"data_description": {
"prompt": "Instruct how the LLM should transcribe the audio."
}
}
},
"abort": {
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
},
"tts": { "tts": {
"initiate_flow": { "initiate_flow": {
"user": "Add Text-to-Speech service", "user": "Add Text-to-Speech service",

View File

@ -0,0 +1,254 @@
"""Speech to text support for Google Generative AI."""
from __future__ import annotations
from collections.abc import AsyncIterable
from google.genai.errors import APIError, ClientError
from google.genai.types import Part
from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_STT_PROMPT,
LOGGER,
RECOMMENDED_STT_MODEL,
)
from .entity import GoogleGenerativeAILLMBaseEntity
from .helpers import convert_to_wav
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up Google Generative AI speech-to-text entities from a config entry.

    One entity is created per ``stt`` subentry; each entity is registered
    under its own subentry ID so it is removed together with the subentry.
    """
    stt_subentries = (
        sub
        for sub in config_entry.subentries.values()
        if sub.subentry_type == "stt"
    )
    for sub in stt_subentries:
        async_add_entities(
            [GoogleGenerativeAISttEntity(config_entry, sub)],
            config_subentry_id=sub.subentry_id,
        )
class GoogleGenerativeAISttEntity(
    stt.SpeechToTextEntity, GoogleGenerativeAILLMBaseEntity
):
    """Speech-to-text entity backed by the Gemini API."""

    def __init__(self, config_entry: ConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the STT entity with the recommended STT model as default."""
        super().__init__(config_entry, subentry, RECOMMENDED_STT_MODEL)

    @property
    def supported_languages(self) -> list[str]:
        """Return the languages this entity can transcribe."""
        # fmt: off
        return [
            "af-ZA", "sq-AL", "am-ET", "ar-DZ", "ar-BH", "ar-EG", "ar-IQ",
            "ar-IL", "ar-JO", "ar-KW", "ar-LB", "ar-MA", "ar-OM", "ar-QA",
            "ar-SA", "ar-PS", "ar-TN", "ar-AE", "ar-YE", "hy-AM", "az-AZ",
            "eu-ES", "bn-BD", "bn-IN", "bs-BA", "bg-BG", "my-MM", "ca-ES",
            "zh-CN", "zh-TW", "hr-HR", "cs-CZ", "da-DK", "nl-BE", "nl-NL",
            "en-AU", "en-CA", "en-GH", "en-HK", "en-IN", "en-IE", "en-KE",
            "en-NZ", "en-NG", "en-PK", "en-PH", "en-SG", "en-ZA", "en-TZ",
            "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-BE", "fr-CA",
            "fr-FR", "fr-CH", "gl-ES", "ka-GE", "de-AT", "de-DE", "de-CH",
            "el-GR", "gu-IN", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID",
            "it-IT", "it-CH", "ja-JP", "jv-ID", "kn-IN", "kk-KZ", "km-KH",
            "ko-KR", "lo-LA", "lv-LV", "lt-LT", "mk-MK", "ms-MY", "ml-IN",
            "mr-IN", "mn-MN", "ne-NP", "no-NO", "fa-IR", "pl-PL", "pt-BR",
            "pt-PT", "ro-RO", "ru-RU", "sr-RS", "si-LK", "sk-SK", "sl-SI",
            "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-DO", "es-EC",
            "es-SV", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PY",
            "es-PE", "es-PR", "es-ES", "es-US", "es-UY", "es-VE", "su-ID",
            "sw-KE", "sw-TZ", "sv-SE", "ta-IN", "ta-MY", "ta-SG", "ta-LK",
            "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "ur-PK", "uz-UZ",
            "vi-VN", "zu-ZA",
        ]
        # fmt: on

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return the supported container formats.

        See https://ai.google.dev/gemini-api/docs/audio#supported-formats
        """
        return [stt.AudioFormats.WAV, stt.AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return the supported audio codecs."""
        return [stt.AudioCodecs.PCM, stt.AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return the supported bit rates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return the supported sample rates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return the supported channel counts.

        Gemini folds multi-channel audio into a single channel
        (https://ai.google.dev/gemini-api/docs/audio), so only mono is
        advertised here.
        """
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Transcribe an audio stream with Gemini and return the result."""
        # Buffer the whole stream: generate_content takes a complete blob,
        # not a stream of chunks.
        audio = b"".join([chunk async for chunk in stream])

        if metadata.format == stt.AudioFormats.WAV:
            # The incoming WAV stream carries raw PCM frames; wrap them in a
            # WAV container so the API receives a valid file.
            audio = convert_to_wav(
                audio,
                f"audio/L{metadata.bit_rate.value};rate={metadata.sample_rate.value}",
            )

        prompt = self.subentry.data.get(CONF_PROMPT, DEFAULT_STT_PROMPT)
        model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_STT_MODEL)
        try:
            response = await self._genai_client.aio.models.generate_content(
                model=model,
                contents=[
                    prompt,
                    Part.from_bytes(
                        data=audio,
                        mime_type=f"audio/{metadata.format.value}",
                    ),
                ],
                config=self.create_generate_content_config(),
            )
        except (APIError, ClientError, ValueError) as err:
            LOGGER.error("Error during STT: %s", err)
        else:
            if response.text:
                return stt.SpeechResult(
                    response.text,
                    stt.SpeechResultState.SUCCESS,
                )
        # Any failure (API error or empty transcription) yields an error result.
        return stt.SpeechResult(None, stt.SpeechResultState.ERROR)

View File

@ -9,6 +9,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -39,6 +40,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"subentry_id": "ulid-conversation", "subentry_id": "ulid-conversation",
"unique_id": None, "unique_id": None,
}, },
{
"data": {},
"subentry_type": "stt",
"title": DEFAULT_STT_NAME,
"subentry_id": "ulid-stt",
"unique_id": None,
},
{ {
"data": {}, "data": {},
"subentry_type": "tts", "subentry_type": "tts",

View File

@ -34,6 +34,14 @@
'title': 'Google AI Conversation', 'title': 'Google AI Conversation',
'unique_id': None, 'unique_id': None,
}), }),
'ulid-stt': dict({
'data': dict({
}),
'subentry_id': 'ulid-stt',
'subentry_type': 'stt',
'title': 'Google AI STT',
'unique_id': None,
}),
'ulid-tts': dict({ 'ulid-tts': dict({
'data': dict({ 'data': dict({
}), }),

View File

@ -32,6 +32,37 @@
'sw_version': None, 'sw_version': None,
'via_device_id': None, 'via_device_id': None,
}), }),
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'config_entries_subentries': <ANY>,
'configuration_url': None,
'connections': set({
}),
'disabled_by': None,
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
'hw_version': None,
'id': <ANY>,
'identifiers': set({
tuple(
'google_generative_ai_conversation',
'ulid-stt',
),
}),
'is_new': False,
'labels': set({
}),
'manufacturer': 'Google',
'model': 'gemini-2.5-flash',
'model_id': None,
'name': 'Google AI STT',
'name_by_user': None,
'primary_config_entry': <ANY>,
'serial_number': None,
'suggested_area': None,
'sw_version': None,
'via_device_id': None,
}),
DeviceRegistryEntrySnapshot({ DeviceRegistryEntrySnapshot({
'area_id': None, 'area_id': None,
'config_entries': <ANY>, 'config_entries': <ANY>,

View File

@ -1,5 +1,6 @@
"""Test the Google Generative AI Conversation config flow.""" """Test the Google Generative AI Conversation config flow."""
from typing import Any
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -21,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_AI_TASK_OPTIONS,
@ -28,8 +30,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_STT_MODEL,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TOP_K, RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
RECOMMENDED_TTS_MODEL,
RECOMMENDED_TTS_OPTIONS, RECOMMENDED_TTS_OPTIONS,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL, RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
) )
@ -64,11 +69,17 @@ def get_models_pager():
) )
model_15_pro.name = "models/gemini-1.5-pro-latest" model_15_pro.name = "models/gemini-1.5-pro-latest"
model_25_flash_tts = Mock(
supported_actions=["generateContent"],
)
model_25_flash_tts.name = "models/gemini-2.5-flash-preview-tts"
async def models_pager(): async def models_pager():
yield model_25_flash yield model_25_flash
yield model_20_flash yield model_20_flash
yield model_15_flash yield model_15_flash
yield model_15_pro yield model_15_pro
yield model_25_flash_tts
return models_pager() return models_pager()
@ -129,6 +140,12 @@ async def test_form(hass: HomeAssistant) -> None:
"title": DEFAULT_AI_TASK_NAME, "title": DEFAULT_AI_TASK_NAME,
"unique_id": None, "unique_id": None,
}, },
{
"subentry_type": "stt",
"data": RECOMMENDED_STT_OPTIONS,
"title": DEFAULT_STT_NAME,
"unique_id": None,
},
] ]
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -157,22 +174,35 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
async def test_creating_conversation_subentry( @pytest.mark.parametrize(
("subentry_type", "options"),
[
("conversation", RECOMMENDED_CONVERSATION_OPTIONS),
("stt", RECOMMENDED_STT_OPTIONS),
("tts", RECOMMENDED_TTS_OPTIONS),
("ai_task_data", RECOMMENDED_AI_TASK_OPTIONS),
],
)
async def test_creating_subentry(
hass: HomeAssistant, hass: HomeAssistant,
mock_init_component: None, mock_init_component: None,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
subentry_type: str,
options: dict[str, Any],
) -> None: ) -> None:
"""Test creating a conversation subentry.""" """Test creating a subentry."""
old_subentries = set(mock_config_entry.subentries)
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
result = await hass.config_entries.subentries.async_init( result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"), (mock_config_entry.entry_id, subentry_type),
context={"source": config_entries.SOURCE_USER}, context={"source": config_entries.SOURCE_USER},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options" assert result["step_id"] == "set_options"
assert not result["errors"] assert not result["errors"]
@ -182,31 +212,117 @@ async def test_creating_conversation_subentry(
): ):
result2 = await hass.config_entries.subentries.async_configure( result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"], result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS}, result["data_schema"]({CONF_NAME: "Mock name", **options}),
) )
await hass.async_block_till_done() await hass.async_block_till_done()
expected_options = options.copy()
if CONF_PROMPT in expected_options:
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name" assert result2["title"] == "Mock name"
assert result2["data"] == expected_options
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy() assert len(mock_config_entry.subentries) == len(old_subentries) + 1
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == subentry_type
assert new_subentry.data == expected_options
assert new_subentry.title == "Mock name"
async def test_creating_tts_subentry( @pytest.mark.parametrize(
("subentry_type", "recommended_model", "options"),
[
(
"conversation",
RECOMMENDED_CHAT_MODEL,
{
CONF_PROMPT: "You are Mario",
CONF_LLM_HASS_API: ["assist"],
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
},
),
(
"stt",
RECOMMENDED_STT_MODEL,
{
CONF_PROMPT: "Transcribe this",
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_STT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
(
"tts",
RECOMMENDED_TTS_MODEL,
{
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_TTS_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
(
"ai_task_data",
RECOMMENDED_CHAT_MODEL,
{
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TEMPERATURE: 1.0,
CONF_TOP_P: 1.0,
CONF_TOP_K: 1,
CONF_MAX_TOKENS: 1024,
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
},
),
],
)
async def test_creating_subentry_custom_options(
hass: HomeAssistant, hass: HomeAssistant,
mock_init_component: None, mock_init_component: None,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
subentry_type: str,
recommended_model: str,
options: dict[str, Any],
) -> None: ) -> None:
"""Test creating a TTS subentry.""" """Test creating a subentry with custom options."""
old_subentries = set(mock_config_entry.subentries)
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
result = await hass.config_entries.subentries.async_init( result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "tts"), (mock_config_entry.entry_id, subentry_type),
context={"source": config_entries.SOURCE_USER}, context={"source": config_entries.SOURCE_USER},
) )
@ -214,75 +330,52 @@ async def test_creating_tts_subentry(
assert result["step_id"] == "set_options" assert result["step_id"] == "set_options"
assert not result["errors"] assert not result["errors"]
old_subentries = set(mock_config_entry.subentries) # Uncheck recommended to show custom options
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
result2 = await hass.config_entries.subentries.async_configure( result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"], result["flow_id"],
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS}, result["data_schema"]({CONF_RECOMMENDED: False}),
) )
await hass.async_block_till_done() assert result2["type"] is FlowResultType.FORM
assert result2["type"] is FlowResultType.CREATE_ENTRY # Find the schema key for CONF_CHAT_MODEL and check its default
assert result2["title"] == "Mock TTS" schema_dict = result2["data_schema"].schema
assert result2["data"] == RECOMMENDED_TTS_OPTIONS chat_model_key = next(key for key in schema_dict if key.schema == CONF_CHAT_MODEL)
assert chat_model_key.default() == recommended_model
models_in_selector = [
opt["value"] for opt in schema_dict[chat_model_key].config["options"]
]
assert recommended_model in models_in_selector
assert len(mock_config_entry.subentries) == 4 # Submit the form
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "tts"
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
assert new_subentry.title == "Mock TTS"
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry."""
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
result = await hass.config_entries.subentries.async_init( result3 = await hass.config_entries.subentries.async_configure(
(mock_config_entry.entry_id, "ai_task_data"), result2["flow_id"],
context={"source": config_entries.SOURCE_USER}, result2["data_schema"]({CONF_NAME: "Mock name", **options}),
)
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY expected_options = options.copy()
assert result2["title"] == "Mock AI Task" if CONF_PROMPT in expected_options:
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["title"] == "Mock name"
assert result3["data"] == expected_options
assert len(mock_config_entry.subentries) == 4 assert len(mock_config_entry.subentries) == len(old_subentries) + 1
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0] new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id] new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task_data" assert new_subentry.subentry_type == subentry_type
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS assert new_subentry.data == expected_options
assert new_subentry.title == "Mock AI Task" assert new_subentry.title == "Mock name"
async def test_creating_conversation_subentry_not_loaded( async def test_creating_conversation_subentry_not_loaded(

View File

@ -11,11 +11,13 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation.const import ( from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DEFAULT_STT_NAME,
DEFAULT_TITLE, DEFAULT_TITLE,
DEFAULT_TTS_NAME, DEFAULT_TTS_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_STT_OPTIONS,
RECOMMENDED_TTS_OPTIONS, RECOMMENDED_TTS_OPTIONS,
) )
from homeassistant.config_entries import ( from homeassistant.config_entries import (
@ -489,7 +491,7 @@ async def test_migration_from_v1(
assert entry.minor_version == 4 assert entry.minor_version == 4
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4 assert len(entry.subentries) == 5
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -516,6 +518,14 @@ async def test_migration_from_v1(
assert len(ai_task_subentries) == 1 assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -721,7 +731,7 @@ async def test_migration_from_v1_disabled(
assert entry.minor_version == 4 assert entry.minor_version == 4
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4 assert len(entry.subentries) == 5
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -748,6 +758,14 @@ async def test_migration_from_v1_disabled(
assert len(ai_task_subentries) == 1 assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
assert not device_registry.async_get_device( assert not device_registry.async_get_device(
identifiers={(DOMAIN, mock_config_entry.entry_id)} identifiers={(DOMAIN, mock_config_entry.entry_id)}
@ -860,7 +878,7 @@ async def test_migration_from_v1_with_multiple_keys(
assert entry.minor_version == 4 assert entry.minor_version == 4
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 3 assert len(entry.subentries) == 4
subentry = list(entry.subentries.values())[0] subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == options assert subentry.data == options
@ -873,6 +891,10 @@ async def test_migration_from_v1_with_multiple_keys(
assert subentry.subentry_type == "ai_task_data" assert subentry.subentry_type == "ai_task_data"
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert subentry.title == DEFAULT_AI_TASK_NAME assert subentry.title == DEFAULT_AI_TASK_NAME
subentry = list(entry.subentries.values())[3]
assert subentry.subentry_type == "stt"
assert subentry.data == RECOMMENDED_STT_OPTIONS
assert subentry.title == DEFAULT_STT_NAME
dev = device_registry.async_get_device( dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
@ -963,7 +985,7 @@ async def test_migration_from_v1_with_same_keys(
assert entry.minor_version == 4 assert entry.minor_version == 4
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4 assert len(entry.subentries) == 5
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -990,6 +1012,14 @@ async def test_migration_from_v1_with_same_keys(
assert len(ai_task_subentries) == 1 assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -1090,10 +1120,11 @@ async def test_migration_from_v2_1(
"""Test migration from version 2.1. """Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1 and add AI Task subentry: 2025.7.0b0-2025.7.0b1 and add AI Task and STT subentries:
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2) - Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1) - Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
- Add AI Task subentry (Added in version 2.3) - Add AI Task subentry (Added in version 2.3)
- Add STT subentry (Added in version 2.3)
""" """
# Create a v2.1 config entry with 2 subentries, devices and entities # Create a v2.1 config entry with 2 subentries, devices and entities
options = { options = {
@ -1184,7 +1215,7 @@ async def test_migration_from_v2_1(
assert entry.minor_version == 4 assert entry.minor_version == 4
assert not entry.options assert not entry.options
assert entry.title == DEFAULT_TITLE assert entry.title == DEFAULT_TITLE
assert len(entry.subentries) == 4 assert len(entry.subentries) == 5
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
@ -1211,6 +1242,14 @@ async def test_migration_from_v2_1(
assert len(ai_task_subentries) == 1 assert len(ai_task_subentries) == 1
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
stt_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "stt"
]
assert len(stt_subentries) == 1
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
assert stt_subentries[0].title == DEFAULT_STT_NAME
subentry = conversation_subentries[0] subentry = conversation_subentries[0]
@ -1320,8 +1359,8 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 4 assert entry.minor_version == 4
# Check we now have conversation, tts and ai_task_data subentries # Check we now have conversation, tts, stt, and ai_task_data subentries
assert len(entry.subentries) == 3 assert len(entry.subentries) == 4
subentries = { subentries = {
subentry.subentry_type: subentry for subentry in entry.subentries.values() subentry.subentry_type: subentry for subentry in entry.subentries.values()
@ -1336,6 +1375,12 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
# Find and verify the stt subentry
ai_task_subentry = subentries["stt"]
assert ai_task_subentry is not None
assert ai_task_subentry.title == DEFAULT_STT_NAME
assert ai_task_subentry.data == RECOMMENDED_STT_OPTIONS
# Verify conversation subentry is still there and unchanged # Verify conversation subentry is still there and unchanged
conversation_subentry = subentries["conversation"] conversation_subentry = subentries["conversation"]
assert conversation_subentry is not None assert conversation_subentry is not None

View File

@ -0,0 +1,303 @@
"""Tests for the Google Generative AI Conversation STT entity."""
from __future__ import annotations
from collections.abc import AsyncIterable, Generator
from unittest.mock import AsyncMock, Mock, patch
from google.genai import types
import pytest
from homeassistant.components import stt
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_STT_PROMPT,
DOMAIN,
RECOMMENDED_STT_MODEL,
)
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
from tests.common import MockConfigEntry
# Model and prompt overrides stored in the STT subentry by these tests;
# chosen to differ from the integration defaults so the tests can tell
# configured values apart from fallback values.
TEST_CHAT_MODEL = "models/gemini-2.5-flash"
TEST_PROMPT = "Please transcribe the audio."
async def _async_get_audio_stream(data: bytes) -> AsyncIterable[bytes]:
"""Yield the audio data."""
yield data
@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
    """Mock genai.Client."""
    # Canned response: a single candidate carrying the transcription text.
    canned_response = types.GenerateContentResponse(
        candidates=[
            {
                "content": {
                    "parts": [{"text": "This is a test transcription."}],
                    "role": "model",
                }
            }
        ]
    )
    fake_client = Mock()
    fake_client.aio.models.get = AsyncMock()
    fake_client.aio.models.generate_content = AsyncMock(return_value=canned_response)
    with patch(
        "homeassistant.components.google_generative_ai_conversation.Client",
        return_value=fake_client,
    ) as patched_client:
        yield patched_client.return_value
@pytest.fixture
async def setup_integration(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Create a config entry with an STT subentry and load the integration."""
    entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    entry.add_to_hass(hass)
    entry.runtime_data = mock_genai_client
    stt_subentry = ConfigSubentry(
        data={CONF_CHAT_MODEL: TEST_CHAT_MODEL, CONF_PROMPT: TEST_PROMPT},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(entry, stt_subentry)
    await hass.config_entries.async_setup(entry.entry_id)
    await hass.async_block_till_done()
@pytest.mark.usefixtures("setup_integration")
async def test_stt_entity_properties(hass: HomeAssistant) -> None:
    """Test STT entity properties."""
    stt_entity: stt.SpeechToTextEntity = hass.data[stt.DOMAIN].get_entity(
        "stt.google_ai_stt"
    )
    assert stt_entity is not None
    assert isinstance(stt_entity.supported_languages, list)
    # Both container formats and both codecs must be advertised.
    for audio_format in (stt.AudioFormats.WAV, stt.AudioFormats.OGG):
        assert audio_format in stt_entity.supported_formats
    for audio_codec in (stt.AudioCodecs.PCM, stt.AudioCodecs.OPUS):
        assert audio_codec in stt_entity.supported_codecs
    assert stt.AudioBitRates.BITRATE_16 in stt_entity.supported_bit_rates
    assert stt.AudioSampleRates.SAMPLERATE_16000 in stt_entity.supported_sample_rates
    assert stt.AudioChannels.CHANNEL_MONO in stt_entity.supported_channels
@pytest.mark.parametrize(
    ("audio_format", "call_convert_to_wav"),
    [
        (stt.AudioFormats.WAV, True),
        (stt.AudioFormats.OGG, False),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_success(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    audio_format: stt.AudioFormats,
    call_convert_to_wav: bool,
) -> None:
    """Test STT processing audio stream successfully."""
    stt_entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=audio_format,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    stream = _async_get_audio_stream(b"test_audio_bytes")

    with patch(
        "homeassistant.components.google_generative_ai_conversation.stt.convert_to_wav",
        return_value=b"converted_wav_bytes",
    ) as mock_convert_to_wav:
        outcome = await stt_entity.async_process_audio_stream(
            speech_metadata, stream
        )

    assert outcome.result == stt.SpeechResultState.SUCCESS
    assert outcome.text == "This is a test transcription."

    mock_genai_client.aio.models.generate_content.assert_called_once()
    request_kwargs = mock_genai_client.aio.models.generate_content.call_args.kwargs
    assert request_kwargs["model"] == TEST_CHAT_MODEL
    request_contents = request_kwargs["contents"]
    assert request_contents[0] == TEST_PROMPT
    audio_part = request_contents[1]
    assert isinstance(audio_part, types.Part)
    assert audio_part.inline_data.mime_type == f"audio/{audio_format.value}"
    if call_convert_to_wav:
        # WAV input is re-encoded before being sent to the API.
        mock_convert_to_wav.assert_called_once_with(
            b"test_audio_bytes", "audio/L16;rate=16000"
        )
        assert audio_part.inline_data.data == b"converted_wav_bytes"
    else:
        mock_convert_to_wav.assert_not_called()
        assert audio_part.inline_data.data == b"test_audio_bytes"
@pytest.mark.parametrize(
    "side_effect",
    [
        API_ERROR_500,
        CLIENT_ERROR_BAD_REQUEST,
        ValueError("Test value error"),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_api_error(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    side_effect: Exception,
) -> None:
    """Test STT processing audio stream with API errors."""
    stt_entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    # Make the mocked API raise instead of returning a transcription.
    mock_genai_client.aio.models.generate_content.side_effect = side_effect
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    stream = _async_get_audio_stream(b"test_audio_bytes")

    outcome = await stt_entity.async_process_audio_stream(speech_metadata, stream)

    assert outcome.result == stt.SpeechResultState.ERROR
    assert outcome.text is None
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_empty_response(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test STT processing with an empty response from the API."""
    stt_entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    # No candidates at all -> no transcription text is available.
    mock_genai_client.aio.models.generate_content.return_value = (
        types.GenerateContentResponse(candidates=[])
    )
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    stream = _async_get_audio_stream(b"test_audio_bytes")

    outcome = await stt_entity.async_process_audio_stream(speech_metadata, stream)

    assert outcome.result == stt.SpeechResultState.ERROR
    assert outcome.text is None
async def test_stt_uses_default_prompt(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default prompt is used if none is configured.

    The STT subentry is created without CONF_PROMPT, so the entity must
    fall back to DEFAULT_STT_PROMPT when building the request contents.
    """
    # NOTE: the fixture is requested via the mock_genai_client parameter;
    # a @pytest.mark.usefixtures("mock_genai_client") marker would be a no-op.
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no prompt
    sub_entry = ConfigSubentry(
        data={CONF_CHAT_MODEL: TEST_CHAT_MODEL},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    call_args = mock_genai_client.aio.models.generate_content.call_args
    contents = call_args.kwargs["contents"]
    assert contents[0] == DEFAULT_STT_PROMPT
async def test_stt_uses_default_model(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default model is used if none is configured.

    The STT subentry is created without CONF_CHAT_MODEL, so the entity must
    fall back to RECOMMENDED_STT_MODEL in the generate_content request.
    """
    # NOTE: the fixture is requested via the mock_genai_client parameter;
    # a @pytest.mark.usefixtures("mock_genai_client") marker would be a no-op.
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no model
    sub_entry = ConfigSubentry(
        data={CONF_PROMPT: TEST_PROMPT},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    call_args = mock_genai_client.aio.models.generate_content.call_args
    assert call_args.kwargs["model"] == RECOMMENDED_STT_MODEL