mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Add Google AI STT (#147563)
This commit is contained in:
parent
26a9af7371
commit
02a11638b3
@ -36,12 +36,14 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_STT_NAME,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_STT_OPTIONS,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
@ -55,6 +57,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (
|
||||
Platform.AI_TASK,
|
||||
Platform.CONVERSATION,
|
||||
Platform.STT,
|
||||
Platform.TTS,
|
||||
)
|
||||
|
||||
@ -301,7 +304,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
||||
if not use_existing:
|
||||
await hass.config_entries.async_remove(entry.entry_id)
|
||||
else:
|
||||
_add_ai_task_subentry(hass, entry)
|
||||
_add_ai_task_and_stt_subentries(hass, entry)
|
||||
hass.config_entries.async_update_entry(
|
||||
entry,
|
||||
title=DEFAULT_TITLE,
|
||||
@ -350,8 +353,7 @@ async def async_migrate_entry(
|
||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||
|
||||
if entry.version == 2 and entry.minor_version == 2:
|
||||
# Add AI Task subentry with default options
|
||||
_add_ai_task_subentry(hass, entry)
|
||||
_add_ai_task_and_stt_subentries(hass, entry)
|
||||
hass.config_entries.async_update_entry(entry, minor_version=3)
|
||||
|
||||
if entry.version == 2 and entry.minor_version == 3:
|
||||
@ -393,10 +395,10 @@ async def async_migrate_entry(
|
||||
return True
|
||||
|
||||
|
||||
def _add_ai_task_subentry(
|
||||
def _add_ai_task_and_stt_subentries(
|
||||
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
|
||||
) -> None:
|
||||
"""Add AI Task subentry to the config entry."""
|
||||
"""Add AI Task and STT subentries to the config entry."""
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
@ -406,3 +408,12 @@ def _add_ai_task_subentry(
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
data=MappingProxyType(RECOMMENDED_STT_OPTIONS),
|
||||
subentry_type="stt",
|
||||
title=DEFAULT_STT_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
|
@ -49,6 +49,8 @@ from .const import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_STT_NAME,
|
||||
DEFAULT_STT_PROMPT,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
@ -57,6 +59,8 @@ from .const import (
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_STT_MODEL,
|
||||
RECOMMENDED_STT_OPTIONS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
@ -144,6 +148,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "stt",
|
||||
"data": RECOMMENDED_STT_OPTIONS,
|
||||
"title": DEFAULT_STT_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
return self.async_show_form(
|
||||
@ -191,6 +201,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Return subentries supported by this integration."""
|
||||
return {
|
||||
"conversation": LLMSubentryFlowHandler,
|
||||
"stt": LLMSubentryFlowHandler,
|
||||
"tts": LLMSubentryFlowHandler,
|
||||
"ai_task_data": LLMSubentryFlowHandler,
|
||||
}
|
||||
@ -228,6 +239,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
options = RECOMMENDED_TTS_OPTIONS.copy()
|
||||
elif self._subentry_type == "ai_task_data":
|
||||
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||
elif self._subentry_type == "stt":
|
||||
options = RECOMMENDED_STT_OPTIONS.copy()
|
||||
else:
|
||||
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
else:
|
||||
@ -304,6 +317,8 @@ async def google_generative_ai_config_option_schema(
|
||||
default_name = DEFAULT_TTS_NAME
|
||||
elif subentry_type == "ai_task_data":
|
||||
default_name = DEFAULT_AI_TASK_NAME
|
||||
elif subentry_type == "stt":
|
||||
default_name = DEFAULT_STT_NAME
|
||||
else:
|
||||
default_name = DEFAULT_CONVERSATION_NAME
|
||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||
@ -331,6 +346,17 @@ async def google_generative_ai_config_option_schema(
|
||||
),
|
||||
}
|
||||
)
|
||||
elif subentry_type == "stt":
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(CONF_PROMPT, DEFAULT_STT_PROMPT)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
}
|
||||
)
|
||||
|
||||
schema.update(
|
||||
{
|
||||
@ -388,6 +414,8 @@ async def google_generative_ai_config_option_schema(
|
||||
|
||||
if subentry_type == "tts":
|
||||
default_model = RECOMMENDED_TTS_MODEL
|
||||
elif subentry_type == "stt":
|
||||
default_model = RECOMMENDED_STT_MODEL
|
||||
else:
|
||||
default_model = RECOMMENDED_CHAT_MODEL
|
||||
|
||||
|
@ -5,18 +5,23 @@ import logging
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
||||
DOMAIN = "google_generative_ai_conversation"
|
||||
DEFAULT_TITLE = "Google Generative AI"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
CONF_PROMPT = "prompt"
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
||||
DEFAULT_STT_NAME = "Google AI STT"
|
||||
DEFAULT_TTS_NAME = "Google AI TTS"
|
||||
DEFAULT_AI_TASK_NAME = "Google AI Task"
|
||||
|
||||
CONF_PROMPT = "prompt"
|
||||
DEFAULT_STT_PROMPT = "Transcribe the attached audio"
|
||||
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
|
||||
RECOMMENDED_STT_MODEL = RECOMMENDED_CHAT_MODEL
|
||||
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
|
||||
CONF_TEMPERATURE = "temperature"
|
||||
RECOMMENDED_TEMPERATURE = 1.0
|
||||
@ -43,6 +48,11 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
||||
RECOMMENDED_STT_OPTIONS = {
|
||||
CONF_PROMPT: DEFAULT_STT_PROMPT,
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
||||
RECOMMENDED_TTS_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
@ -61,6 +61,38 @@
|
||||
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
|
||||
}
|
||||
},
|
||||
"stt": {
|
||||
"initiate_flow": {
|
||||
"user": "Add Speech-to-Text service",
|
||||
"reconfigure": "Reconfigure Speech-to-Text service"
|
||||
},
|
||||
"entry_type": "Speech-to-Text",
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
|
||||
"prompt": "Instructions",
|
||||
"chat_model": "[%key:common::generic::model%]",
|
||||
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
|
||||
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
|
||||
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
|
||||
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
|
||||
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
|
||||
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
|
||||
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
|
||||
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Instruct how the LLM should transcribe the audio."
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"initiate_flow": {
|
||||
"user": "Add Text-to-Speech service",
|
||||
|
@ -0,0 +1,254 @@
|
||||
"""Speech to text support for Google Generative AI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
|
||||
from google.genai.errors import APIError, ClientError
|
||||
from google.genai.types import Part
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_STT_PROMPT,
|
||||
LOGGER,
|
||||
RECOMMENDED_STT_MODEL,
|
||||
)
|
||||
from .entity import GoogleGenerativeAILLMBaseEntity
|
||||
from .helpers import convert_to_wav
|
||||
|
||||
|
||||
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Register one STT entity for each "stt" subentry of the config entry."""
    # Lazily select only the subentries of type "stt"; everything else is
    # handled by other platforms.
    stt_subentries = (
        candidate
        for candidate in config_entry.subentries.values()
        if candidate.subentry_type == "stt"
    )
    for stt_subentry in stt_subentries:
        # Each entity is added on its own so it is tied to its subentry id.
        async_add_entities(
            [GoogleGenerativeAISttEntity(config_entry, stt_subentry)],
            config_subentry_id=stt_subentry.subentry_id,
        )
|
||||
|
||||
|
||||
class GoogleGenerativeAISttEntity(
    stt.SpeechToTextEntity, GoogleGenerativeAILLMBaseEntity
):
    """Google Generative AI speech-to-text entity.

    The whole audio stream is collected in memory, sent to the model together
    with the configured prompt through ``generate_content``, and the response
    text is returned as the transcription.
    """

    def __init__(self, config_entry: ConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the STT entity with the recommended STT model as default."""
        super().__init__(config_entry, subentry, RECOMMENDED_STT_MODEL)

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return [
            "af-ZA",
            "sq-AL",
            "am-ET",
            "ar-DZ",
            "ar-BH",
            "ar-EG",
            "ar-IQ",
            "ar-IL",
            "ar-JO",
            "ar-KW",
            "ar-LB",
            "ar-MA",
            "ar-OM",
            "ar-QA",
            "ar-SA",
            "ar-PS",
            "ar-TN",
            "ar-AE",
            "ar-YE",
            "hy-AM",
            "az-AZ",
            "eu-ES",
            "bn-BD",
            "bn-IN",
            "bs-BA",
            "bg-BG",
            "my-MM",
            "ca-ES",
            "zh-CN",
            "zh-TW",
            "hr-HR",
            "cs-CZ",
            "da-DK",
            "nl-BE",
            "nl-NL",
            "en-AU",
            "en-CA",
            "en-GH",
            "en-HK",
            "en-IN",
            "en-IE",
            "en-KE",
            "en-NZ",
            "en-NG",
            "en-PK",
            "en-PH",
            "en-SG",
            "en-ZA",
            "en-TZ",
            "en-GB",
            "en-US",
            "et-EE",
            "fil-PH",
            "fi-FI",
            "fr-BE",
            "fr-CA",
            "fr-FR",
            "fr-CH",
            "gl-ES",
            "ka-GE",
            "de-AT",
            "de-DE",
            "de-CH",
            "el-GR",
            "gu-IN",
            "iw-IL",
            "hi-IN",
            "hu-HU",
            "is-IS",
            "id-ID",
            "it-IT",
            "it-CH",
            "ja-JP",
            "jv-ID",
            "kn-IN",
            "kk-KZ",
            "km-KH",
            "ko-KR",
            "lo-LA",
            "lv-LV",
            "lt-LT",
            "mk-MK",
            "ms-MY",
            "ml-IN",
            "mr-IN",
            "mn-MN",
            "ne-NP",
            "no-NO",
            "fa-IR",
            "pl-PL",
            "pt-BR",
            "pt-PT",
            "ro-RO",
            "ru-RU",
            "sr-RS",
            "si-LK",
            "sk-SK",
            "sl-SI",
            "es-AR",
            "es-BO",
            "es-CL",
            "es-CO",
            "es-CR",
            "es-DO",
            "es-EC",
            "es-SV",
            "es-GT",
            "es-HN",
            "es-MX",
            "es-NI",
            "es-PA",
            "es-PY",
            "es-PE",
            "es-PR",
            "es-ES",
            "es-US",
            "es-UY",
            "es-VE",
            "su-ID",
            "sw-KE",
            "sw-TZ",
            "sv-SE",
            "ta-IN",
            "ta-MY",
            "ta-SG",
            "ta-LK",
            "te-IN",
            "th-TH",
            "tr-TR",
            "uk-UA",
            "ur-IN",
            "ur-PK",
            "uz-UZ",
            "vi-VN",
            "zu-ZA",
        ]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return a list of supported formats."""
        # https://ai.google.dev/gemini-api/docs/audio#supported-formats
        return [stt.AudioFormats.WAV, stt.AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return a list of supported codecs."""
        return [stt.AudioCodecs.PCM, stt.AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return a list of supported bit rates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return a list of supported sample rates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return a list of supported channels."""
        # Per https://ai.google.dev/gemini-api/docs/audio
        # If the audio source contains multiple channels, Gemini combines those channels into a single channel.
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream to STT service.

        Args:
            metadata: Describes the incoming audio (format, bit rate, sample
                rate); used to build the MIME types passed along with the data.
            stream: Asynchronous iterable yielding the audio in chunks.

        Returns:
            A ``SpeechResult`` in state ``SUCCESS`` carrying the response text,
            or a ``SpeechResult`` with ``None`` text in state ``ERROR`` when
            the request raised one of the handled errors or the response
            contained no text.
        """
        # Gather all chunks and join once; repeated ``bytes +=`` would copy
        # the growing buffer on every chunk.
        audio_data = b"".join([chunk async for chunk in stream])
        if metadata.format == stt.AudioFormats.WAV:
            # NOTE(review): presumably the stream carries raw PCM samples that
            # convert_to_wav wraps in a WAV container — confirm in helpers.
            audio_data = convert_to_wav(
                audio_data,
                f"audio/L{metadata.bit_rate.value};rate={metadata.sample_rate.value}",
            )

        try:
            response = await self._genai_client.aio.models.generate_content(
                model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_STT_MODEL),
                contents=[
                    # Prompt first, then the audio payload as an inline part.
                    self.subentry.data.get(CONF_PROMPT, DEFAULT_STT_PROMPT),
                    Part.from_bytes(
                        data=audio_data,
                        mime_type=f"audio/{metadata.format.value}",
                    ),
                ],
                config=self.create_generate_content_config(),
            )
        except (APIError, ClientError, ValueError) as err:
            # Log and fall through to the error result below instead of
            # propagating the exception to the caller.
            LOGGER.error("Error during STT: %s", err)
        else:
            if response.text:
                return stt.SpeechResult(
                    response.text,
                    stt.SpeechResultState.SUCCESS,
                )

        # Reached on a handled exception or an empty response text.
        return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
|
@ -9,6 +9,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_STT_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@ -39,6 +40,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"subentry_id": "ulid-conversation",
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": {},
|
||||
"subentry_type": "stt",
|
||||
"title": DEFAULT_STT_NAME,
|
||||
"subentry_id": "ulid-stt",
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": {},
|
||||
"subentry_type": "tts",
|
||||
|
@ -34,6 +34,14 @@
|
||||
'title': 'Google AI Conversation',
|
||||
'unique_id': None,
|
||||
}),
|
||||
'ulid-stt': dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'subentry_id': 'ulid-stt',
|
||||
'subentry_type': 'stt',
|
||||
'title': 'Google AI STT',
|
||||
'unique_id': None,
|
||||
}),
|
||||
'ulid-tts': dict({
|
||||
'data': dict({
|
||||
}),
|
||||
|
@ -32,6 +32,37 @@
|
||||
'sw_version': None,
|
||||
'via_device_id': None,
|
||||
}),
|
||||
DeviceRegistryEntrySnapshot({
|
||||
'area_id': None,
|
||||
'config_entries': <ANY>,
|
||||
'config_entries_subentries': <ANY>,
|
||||
'configuration_url': None,
|
||||
'connections': set({
|
||||
}),
|
||||
'disabled_by': None,
|
||||
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
|
||||
'hw_version': None,
|
||||
'id': <ANY>,
|
||||
'identifiers': set({
|
||||
tuple(
|
||||
'google_generative_ai_conversation',
|
||||
'ulid-stt',
|
||||
),
|
||||
}),
|
||||
'is_new': False,
|
||||
'labels': set({
|
||||
}),
|
||||
'manufacturer': 'Google',
|
||||
'model': 'gemini-2.5-flash',
|
||||
'model_id': None,
|
||||
'name': 'Google AI STT',
|
||||
'name_by_user': None,
|
||||
'primary_config_entry': <ANY>,
|
||||
'serial_number': None,
|
||||
'suggested_area': None,
|
||||
'sw_version': None,
|
||||
'via_device_id': None,
|
||||
}),
|
||||
DeviceRegistryEntrySnapshot({
|
||||
'area_id': None,
|
||||
'config_entries': <ANY>,
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test the Google Generative AI Conversation config flow."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -21,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_STT_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
@ -28,8 +30,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_STT_MODEL,
|
||||
RECOMMENDED_STT_OPTIONS,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_MODEL,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
)
|
||||
@ -64,11 +69,17 @@ def get_models_pager():
|
||||
)
|
||||
model_15_pro.name = "models/gemini-1.5-pro-latest"
|
||||
|
||||
model_25_flash_tts = Mock(
|
||||
supported_actions=["generateContent"],
|
||||
)
|
||||
model_25_flash_tts.name = "models/gemini-2.5-flash-preview-tts"
|
||||
|
||||
async def models_pager():
|
||||
yield model_25_flash
|
||||
yield model_20_flash
|
||||
yield model_15_flash
|
||||
yield model_15_pro
|
||||
yield model_25_flash_tts
|
||||
|
||||
return models_pager()
|
||||
|
||||
@ -129,6 +140,12 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "stt",
|
||||
"data": RECOMMENDED_STT_OPTIONS,
|
||||
"title": DEFAULT_STT_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
]
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@ -157,22 +174,35 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_creating_conversation_subentry(
|
||||
@pytest.mark.parametrize(
|
||||
("subentry_type", "options"),
|
||||
[
|
||||
("conversation", RECOMMENDED_CONVERSATION_OPTIONS),
|
||||
("stt", RECOMMENDED_STT_OPTIONS),
|
||||
("tts", RECOMMENDED_TTS_OPTIONS),
|
||||
("ai_task_data", RECOMMENDED_AI_TASK_OPTIONS),
|
||||
],
|
||||
)
|
||||
async def test_creating_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
subentry_type: str,
|
||||
options: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test creating a conversation subentry."""
|
||||
"""Test creating a subentry."""
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
(mock_config_entry.entry_id, subentry_type),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["type"] is FlowResultType.FORM, result
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
@ -182,31 +212,117 @@ async def test_creating_conversation_subentry(
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
|
||||
result["data_schema"]({CONF_NAME: "Mock name", **options}),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
expected_options = options.copy()
|
||||
if CONF_PROMPT in expected_options:
|
||||
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock name"
|
||||
assert result2["data"] == expected_options
|
||||
|
||||
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
||||
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
|
||||
|
||||
assert result2["data"] == processed_options
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == subentry_type
|
||||
assert new_subentry.data == expected_options
|
||||
assert new_subentry.title == "Mock name"
|
||||
|
||||
|
||||
async def test_creating_tts_subentry(
|
||||
@pytest.mark.parametrize(
|
||||
("subentry_type", "recommended_model", "options"),
|
||||
[
|
||||
(
|
||||
"conversation",
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
{
|
||||
CONF_PROMPT: "You are Mario",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TEMPERATURE: 1.0,
|
||||
CONF_TOP_P: 1.0,
|
||||
CONF_TOP_K: 1,
|
||||
CONF_MAX_TOKENS: 1024,
|
||||
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
},
|
||||
),
|
||||
(
|
||||
"stt",
|
||||
RECOMMENDED_STT_MODEL,
|
||||
{
|
||||
CONF_PROMPT: "Transcribe this",
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_STT_MODEL,
|
||||
CONF_TEMPERATURE: 1.0,
|
||||
CONF_TOP_P: 1.0,
|
||||
CONF_TOP_K: 1,
|
||||
CONF_MAX_TOKENS: 1024,
|
||||
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
),
|
||||
(
|
||||
"tts",
|
||||
RECOMMENDED_TTS_MODEL,
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_TTS_MODEL,
|
||||
CONF_TEMPERATURE: 1.0,
|
||||
CONF_TOP_P: 1.0,
|
||||
CONF_TOP_K: 1,
|
||||
CONF_MAX_TOKENS: 1024,
|
||||
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
),
|
||||
(
|
||||
"ai_task_data",
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TEMPERATURE: 1.0,
|
||||
CONF_TOP_P: 1.0,
|
||||
CONF_TOP_K: 1,
|
||||
CONF_MAX_TOKENS: 1024,
|
||||
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_creating_subentry_custom_options(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
subentry_type: str,
|
||||
recommended_model: str,
|
||||
options: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test creating a TTS subentry."""
|
||||
"""Test creating a subentry with custom options."""
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "tts"),
|
||||
(mock_config_entry.entry_id, subentry_type),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
@ -214,75 +330,52 @@ async def test_creating_tts_subentry(
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
# Uncheck recommended to show custom options
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS},
|
||||
result["data_schema"]({CONF_RECOMMENDED: False}),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock TTS"
|
||||
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
|
||||
# Find the schema key for CONF_CHAT_MODEL and check its default
|
||||
schema_dict = result2["data_schema"].schema
|
||||
chat_model_key = next(key for key in schema_dict if key.schema == CONF_CHAT_MODEL)
|
||||
assert chat_model_key.default() == recommended_model
|
||||
models_in_selector = [
|
||||
opt["value"] for opt in schema_dict[chat_model_key].config["options"]
|
||||
]
|
||||
assert recommended_model in models_in_selector
|
||||
|
||||
assert len(mock_config_entry.subentries) == 4
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == "tts"
|
||||
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
|
||||
assert new_subentry.title == "Mock TTS"
|
||||
|
||||
|
||||
async def test_creating_ai_task_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry."""
|
||||
# Submit the form
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM, result
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
|
||||
result3 = await hass.config_entries.subentries.async_configure(
|
||||
result2["flow_id"],
|
||||
result2["data_schema"]({CONF_NAME: "Mock name", **options}),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock AI Task"
|
||||
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
|
||||
expected_options = options.copy()
|
||||
if CONF_PROMPT in expected_options:
|
||||
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
|
||||
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result3["title"] == "Mock name"
|
||||
assert result3["data"] == expected_options
|
||||
|
||||
assert len(mock_config_entry.subentries) == 4
|
||||
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == "ai_task_data"
|
||||
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert new_subentry.title == "Mock AI Task"
|
||||
assert new_subentry.subentry_type == subentry_type
|
||||
assert new_subentry.data == expected_options
|
||||
assert new_subentry.title == "Mock name"
|
||||
|
||||
|
||||
async def test_creating_conversation_subentry_not_loaded(
|
||||
|
@ -11,11 +11,13 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_STT_NAME,
|
||||
DEFAULT_TITLE,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_STT_OPTIONS,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import (
|
||||
@ -489,7 +491,7 @@ async def test_migration_from_v1(
|
||||
assert entry.minor_version == 4
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 4
|
||||
assert len(entry.subentries) == 5
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -516,6 +518,14 @@ async def test_migration_from_v1(
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
stt_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "stt"
|
||||
]
|
||||
assert len(stt_subentries) == 1
|
||||
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -721,7 +731,7 @@ async def test_migration_from_v1_disabled(
|
||||
assert entry.minor_version == 4
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 4
|
||||
assert len(entry.subentries) == 5
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -748,6 +758,14 @@ async def test_migration_from_v1_disabled(
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
stt_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "stt"
|
||||
]
|
||||
assert len(stt_subentries) == 1
|
||||
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||
|
||||
assert not device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, mock_config_entry.entry_id)}
|
||||
@ -860,7 +878,7 @@ async def test_migration_from_v1_with_multiple_keys(
|
||||
assert entry.minor_version == 4
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 3
|
||||
assert len(entry.subentries) == 4
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
@ -873,6 +891,10 @@ async def test_migration_from_v1_with_multiple_keys(
|
||||
assert subentry.subentry_type == "ai_task_data"
|
||||
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert subentry.title == DEFAULT_AI_TASK_NAME
|
||||
subentry = list(entry.subentries.values())[3]
|
||||
assert subentry.subentry_type == "stt"
|
||||
assert subentry.data == RECOMMENDED_STT_OPTIONS
|
||||
assert subentry.title == DEFAULT_STT_NAME
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||
@ -963,7 +985,7 @@ async def test_migration_from_v1_with_same_keys(
|
||||
assert entry.minor_version == 4
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 4
|
||||
assert len(entry.subentries) == 5
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -990,6 +1012,14 @@ async def test_migration_from_v1_with_same_keys(
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
stt_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "stt"
|
||||
]
|
||||
assert len(stt_subentries) == 1
|
||||
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -1090,10 +1120,11 @@ async def test_migration_from_v2_1(
|
||||
"""Test migration from version 2.1.
|
||||
|
||||
This tests we clean up the broken migration in Home Assistant Core
|
||||
2025.7.0b0-2025.7.0b1 and add AI Task subentry:
|
||||
2025.7.0b0-2025.7.0b1 and add AI Task and STT subentries:
|
||||
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
|
||||
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
|
||||
- Add AI Task subentry (Added in version 2.3)
|
||||
- Add STT subentry (Added in version 2.3)
|
||||
"""
|
||||
# Create a v2.1 config entry with 2 subentries, devices and entities
|
||||
options = {
|
||||
@ -1184,7 +1215,7 @@ async def test_migration_from_v2_1(
|
||||
assert entry.minor_version == 4
|
||||
assert not entry.options
|
||||
assert entry.title == DEFAULT_TITLE
|
||||
assert len(entry.subentries) == 4
|
||||
assert len(entry.subentries) == 5
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -1211,6 +1242,14 @@ async def test_migration_from_v2_1(
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
stt_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "stt"
|
||||
]
|
||||
assert len(stt_subentries) == 1
|
||||
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||
|
||||
subentry = conversation_subentries[0]
|
||||
|
||||
@ -1320,8 +1359,8 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
|
||||
# Check we now have conversation, tts and ai_task_data subentries
|
||||
assert len(entry.subentries) == 3
|
||||
# Check we now have conversation, tts, stt, and ai_task_data subentries
|
||||
assert len(entry.subentries) == 4
|
||||
|
||||
subentries = {
|
||||
subentry.subentry_type: subentry for subentry in entry.subentries.values()
|
||||
@ -1336,6 +1375,12 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
|
||||
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
|
||||
# Find and verify the stt subentry
|
||||
ai_task_subentry = subentries["stt"]
|
||||
assert ai_task_subentry is not None
|
||||
assert ai_task_subentry.title == DEFAULT_STT_NAME
|
||||
assert ai_task_subentry.data == RECOMMENDED_STT_OPTIONS
|
||||
|
||||
# Verify conversation subentry is still there and unchanged
|
||||
conversation_subentry = subentries["conversation"]
|
||||
assert conversation_subentry is not None
|
||||
|
303
tests/components/google_generative_ai_conversation/test_stt.py
Normal file
303
tests/components/google_generative_ai_conversation/test_stt.py
Normal file
@ -0,0 +1,303 @@
|
||||
"""Tests for the Google Generative AI Conversation STT entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_STT_PROMPT,
|
||||
DOMAIN,
|
||||
RECOMMENDED_STT_MODEL,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
# Model name and prompt stored in the STT subentry used by these tests.
TEST_CHAT_MODEL = "models/gemini-2.5-flash"
TEST_PROMPT = "Please transcribe the audio."
|
||||
|
||||
|
||||
async def _async_get_audio_stream(data: bytes) -> AsyncIterable[bytes]:
    """Yield the audio data as a single chunk.

    Async generator used by the tests to stand in for an incoming audio
    stream: it produces ``data`` exactly once and then finishes.
    """
    yield data
|
||||
|
||||
|
||||
@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
    """Mock genai.Client.

    Builds a ``Mock`` client whose ``aio.models.get`` and
    ``aio.models.generate_content`` are ``AsyncMock`` objects, and patches
    ``Client`` in the integration package so that constructing it returns
    that mock. ``generate_content`` answers with a canned response holding
    the single text part "This is a test transcription.".
    """
    client = Mock()
    client.aio.models.get = AsyncMock()
    client.aio.models.generate_content = AsyncMock(
        return_value=types.GenerateContentResponse(
            candidates=[
                {
                    "content": {
                        "parts": [{"text": "This is a test transcription."}],
                        "role": "model",
                    }
                }
            ]
        )
    )
    with patch(
        "homeassistant.components.google_generative_ai_conversation.Client",
        return_value=client,
    ) as mock_client:
        # Yield the object that Client(...) returns, so tests can inspect
        # and reconfigure the calls made through it.
        yield mock_client.return_value
|
||||
|
||||
|
||||
@pytest.fixture
async def setup_integration(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Set up the test environment.

    Registers a version 2.1 config entry carrying one "stt" subentry
    (configured with TEST_CHAT_MODEL and TEST_PROMPT), then sets the entry
    up and waits for pending work to finish.
    """
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)

    sub_entry = ConfigSubentry(
        data={
            CONF_CHAT_MODEL: TEST_CHAT_MODEL,
            CONF_PROMPT: TEST_PROMPT,
        },
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )

    config_entry.runtime_data = mock_genai_client

    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_integration")
async def test_stt_entity_properties(hass: HomeAssistant) -> None:
    """Test STT entity properties.

    Looks up the entity created for the "Google AI STT" subentry and checks
    the formats, codecs, bit rate, sample rate and channel it advertises.
    """
    entity: stt.SpeechToTextEntity = hass.data[stt.DOMAIN].get_entity(
        "stt.google_ai_stt"
    )
    assert entity is not None
    assert isinstance(entity.supported_languages, list)
    assert stt.AudioFormats.WAV in entity.supported_formats
    assert stt.AudioFormats.OGG in entity.supported_formats
    assert stt.AudioCodecs.PCM in entity.supported_codecs
    assert stt.AudioCodecs.OPUS in entity.supported_codecs
    assert stt.AudioBitRates.BITRATE_16 in entity.supported_bit_rates
    assert stt.AudioSampleRates.SAMPLERATE_16000 in entity.supported_sample_rates
    assert stt.AudioChannels.CHANNEL_MONO in entity.supported_channels
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("audio_format", "call_convert_to_wav"),
    [
        (stt.AudioFormats.WAV, True),
        (stt.AudioFormats.OGG, False),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_success(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    audio_format: stt.AudioFormats,
    call_convert_to_wav: bool,
) -> None:
    """Test STT processing audio stream successfully.

    Parametrized over the input format: for WAV the audio is expected to go
    through ``convert_to_wav`` before being sent; for OGG the raw bytes are
    expected to be sent unchanged.
    """
    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=audio_format,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    # Replace the converter with a stub so the test can tell converted
    # bytes apart from the raw input bytes.
    with patch(
        "homeassistant.components.google_generative_ai_conversation.stt.convert_to_wav",
        return_value=b"converted_wav_bytes",
    ) as mock_convert_to_wav:
        result = await entity.async_process_audio_stream(metadata, audio_stream)

    assert result.result == stt.SpeechResultState.SUCCESS
    assert result.text == "This is a test transcription."

    if call_convert_to_wav:
        mock_convert_to_wav.assert_called_once_with(
            b"test_audio_bytes", "audio/L16;rate=16000"
        )
    else:
        mock_convert_to_wav.assert_not_called()

    mock_genai_client.aio.models.generate_content.assert_called_once()
    call_args = mock_genai_client.aio.models.generate_content.call_args
    assert call_args.kwargs["model"] == TEST_CHAT_MODEL

    # The request carries the prompt first, followed by the audio part.
    contents = call_args.kwargs["contents"]
    assert contents[0] == TEST_PROMPT
    assert isinstance(contents[1], types.Part)
    assert contents[1].inline_data.mime_type == f"audio/{audio_format.value}"
    if call_convert_to_wav:
        assert contents[1].inline_data.data == b"converted_wav_bytes"
    else:
        assert contents[1].inline_data.data == b"test_audio_bytes"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "side_effect",
    [
        API_ERROR_500,
        CLIENT_ERROR_BAD_REQUEST,
        ValueError("Test value error"),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_api_error(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    side_effect: Exception,
) -> None:
    """Test STT processing audio stream with API errors.

    Each parametrized exception is raised by the mocked ``generate_content``
    call; the entity is expected to report an error result with no text.
    """
    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    mock_genai_client.aio.models.generate_content.side_effect = side_effect

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    result = await entity.async_process_audio_stream(metadata, audio_stream)

    assert result.result == stt.SpeechResultState.ERROR
    assert result.text is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_empty_response(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test STT processing with an empty response from the API.

    The mocked ``generate_content`` returns a response without candidates;
    the entity is expected to report an error result with no text.
    """
    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    mock_genai_client.aio.models.generate_content.return_value = (
        types.GenerateContentResponse(candidates=[])
    )

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    result = await entity.async_process_audio_stream(metadata, audio_stream)

    assert result.result == stt.SpeechResultState.ERROR
    assert result.text is None
|
||||
|
||||
|
||||
async def test_stt_uses_default_prompt(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default prompt is used if none is configured.

    The ``mock_genai_client`` fixture is requested through the argument
    list, which is enough to activate it; no ``usefixtures`` marker is
    needed.
    """
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no prompt
    sub_entry = ConfigSubentry(
        data={CONF_CHAT_MODEL: TEST_CHAT_MODEL},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    # The prompt is the first element of the request contents.
    call_args = mock_genai_client.aio.models.generate_content.call_args
    contents = call_args.kwargs["contents"]
    assert contents[0] == DEFAULT_STT_PROMPT
|
||||
|
||||
|
||||
async def test_stt_uses_default_model(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default model is used if none is configured.

    The ``mock_genai_client`` fixture is requested through the argument
    list, which is enough to activate it; no ``usefixtures`` marker is
    needed.
    """
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no model
    sub_entry = ConfigSubentry(
        data={CONF_PROMPT: TEST_PROMPT},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    call_args = mock_genai_client.aio.models.generate_content.call_args
    assert call_args.kwargs["model"] == RECOMMENDED_STT_MODEL
|
Loading…
x
Reference in New Issue
Block a user