mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Add Google AI STT (#147563)
This commit is contained in:
parent
26a9af7371
commit
02a11638b3
@ -36,12 +36,14 @@ from homeassistant.helpers.typing import ConfigType
|
|||||||
from .const import (
|
from .const import (
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
DEFAULT_AI_TASK_NAME,
|
DEFAULT_AI_TASK_NAME,
|
||||||
|
DEFAULT_STT_NAME,
|
||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RECOMMENDED_AI_TASK_OPTIONS,
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_STT_OPTIONS,
|
||||||
RECOMMENDED_TTS_OPTIONS,
|
RECOMMENDED_TTS_OPTIONS,
|
||||||
TIMEOUT_MILLIS,
|
TIMEOUT_MILLIS,
|
||||||
)
|
)
|
||||||
@ -55,6 +57,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|||||||
PLATFORMS = (
|
PLATFORMS = (
|
||||||
Platform.AI_TASK,
|
Platform.AI_TASK,
|
||||||
Platform.CONVERSATION,
|
Platform.CONVERSATION,
|
||||||
|
Platform.STT,
|
||||||
Platform.TTS,
|
Platform.TTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -301,7 +304,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
|||||||
if not use_existing:
|
if not use_existing:
|
||||||
await hass.config_entries.async_remove(entry.entry_id)
|
await hass.config_entries.async_remove(entry.entry_id)
|
||||||
else:
|
else:
|
||||||
_add_ai_task_subentry(hass, entry)
|
_add_ai_task_and_stt_subentries(hass, entry)
|
||||||
hass.config_entries.async_update_entry(
|
hass.config_entries.async_update_entry(
|
||||||
entry,
|
entry,
|
||||||
title=DEFAULT_TITLE,
|
title=DEFAULT_TITLE,
|
||||||
@ -350,8 +353,7 @@ async def async_migrate_entry(
|
|||||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||||
|
|
||||||
if entry.version == 2 and entry.minor_version == 2:
|
if entry.version == 2 and entry.minor_version == 2:
|
||||||
# Add AI Task subentry with default options
|
_add_ai_task_and_stt_subentries(hass, entry)
|
||||||
_add_ai_task_subentry(hass, entry)
|
|
||||||
hass.config_entries.async_update_entry(entry, minor_version=3)
|
hass.config_entries.async_update_entry(entry, minor_version=3)
|
||||||
|
|
||||||
if entry.version == 2 and entry.minor_version == 3:
|
if entry.version == 2 and entry.minor_version == 3:
|
||||||
@ -393,10 +395,10 @@ async def async_migrate_entry(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _add_ai_task_subentry(
|
def _add_ai_task_and_stt_subentries(
|
||||||
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
|
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add AI Task subentry to the config entry."""
|
"""Add AI Task and STT subentries to the config entry."""
|
||||||
hass.config_entries.async_add_subentry(
|
hass.config_entries.async_add_subentry(
|
||||||
entry,
|
entry,
|
||||||
ConfigSubentry(
|
ConfigSubentry(
|
||||||
@ -406,3 +408,12 @@ def _add_ai_task_subentry(
|
|||||||
unique_id=None,
|
unique_id=None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
hass.config_entries.async_add_subentry(
|
||||||
|
entry,
|
||||||
|
ConfigSubentry(
|
||||||
|
data=MappingProxyType(RECOMMENDED_STT_OPTIONS),
|
||||||
|
subentry_type="stt",
|
||||||
|
title=DEFAULT_STT_NAME,
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@ -49,6 +49,8 @@ from .const import (
|
|||||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DEFAULT_AI_TASK_NAME,
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
DEFAULT_STT_NAME,
|
||||||
|
DEFAULT_STT_PROMPT,
|
||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
@ -57,6 +59,8 @@ from .const import (
|
|||||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_STT_MODEL,
|
||||||
|
RECOMMENDED_STT_OPTIONS,
|
||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
RECOMMENDED_TOP_K,
|
RECOMMENDED_TOP_K,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
@ -144,6 +148,12 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
"title": DEFAULT_AI_TASK_NAME,
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "stt",
|
||||||
|
"data": RECOMMENDED_STT_OPTIONS,
|
||||||
|
"title": DEFAULT_STT_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
@ -191,6 +201,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
"""Return subentries supported by this integration."""
|
"""Return subentries supported by this integration."""
|
||||||
return {
|
return {
|
||||||
"conversation": LLMSubentryFlowHandler,
|
"conversation": LLMSubentryFlowHandler,
|
||||||
|
"stt": LLMSubentryFlowHandler,
|
||||||
"tts": LLMSubentryFlowHandler,
|
"tts": LLMSubentryFlowHandler,
|
||||||
"ai_task_data": LLMSubentryFlowHandler,
|
"ai_task_data": LLMSubentryFlowHandler,
|
||||||
}
|
}
|
||||||
@ -228,6 +239,8 @@ class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
options = RECOMMENDED_TTS_OPTIONS.copy()
|
options = RECOMMENDED_TTS_OPTIONS.copy()
|
||||||
elif self._subentry_type == "ai_task_data":
|
elif self._subentry_type == "ai_task_data":
|
||||||
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||||
|
elif self._subentry_type == "stt":
|
||||||
|
options = RECOMMENDED_STT_OPTIONS.copy()
|
||||||
else:
|
else:
|
||||||
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||||
else:
|
else:
|
||||||
@ -304,6 +317,8 @@ async def google_generative_ai_config_option_schema(
|
|||||||
default_name = DEFAULT_TTS_NAME
|
default_name = DEFAULT_TTS_NAME
|
||||||
elif subentry_type == "ai_task_data":
|
elif subentry_type == "ai_task_data":
|
||||||
default_name = DEFAULT_AI_TASK_NAME
|
default_name = DEFAULT_AI_TASK_NAME
|
||||||
|
elif subentry_type == "stt":
|
||||||
|
default_name = DEFAULT_STT_NAME
|
||||||
else:
|
else:
|
||||||
default_name = DEFAULT_CONVERSATION_NAME
|
default_name = DEFAULT_CONVERSATION_NAME
|
||||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||||
@ -331,6 +346,17 @@ async def google_generative_ai_config_option_schema(
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif subentry_type == "stt":
|
||||||
|
schema.update(
|
||||||
|
{
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PROMPT,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(CONF_PROMPT, DEFAULT_STT_PROMPT)
|
||||||
|
},
|
||||||
|
): TemplateSelector(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
schema.update(
|
schema.update(
|
||||||
{
|
{
|
||||||
@ -388,6 +414,8 @@ async def google_generative_ai_config_option_schema(
|
|||||||
|
|
||||||
if subentry_type == "tts":
|
if subentry_type == "tts":
|
||||||
default_model = RECOMMENDED_TTS_MODEL
|
default_model = RECOMMENDED_TTS_MODEL
|
||||||
|
elif subentry_type == "stt":
|
||||||
|
default_model = RECOMMENDED_STT_MODEL
|
||||||
else:
|
else:
|
||||||
default_model = RECOMMENDED_CHAT_MODEL
|
default_model = RECOMMENDED_CHAT_MODEL
|
||||||
|
|
||||||
|
@ -5,18 +5,23 @@ import logging
|
|||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__package__)
|
||||||
|
|
||||||
DOMAIN = "google_generative_ai_conversation"
|
DOMAIN = "google_generative_ai_conversation"
|
||||||
DEFAULT_TITLE = "Google Generative AI"
|
DEFAULT_TITLE = "Google Generative AI"
|
||||||
LOGGER = logging.getLogger(__package__)
|
|
||||||
CONF_PROMPT = "prompt"
|
|
||||||
|
|
||||||
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
||||||
|
DEFAULT_STT_NAME = "Google AI STT"
|
||||||
DEFAULT_TTS_NAME = "Google AI TTS"
|
DEFAULT_TTS_NAME = "Google AI TTS"
|
||||||
DEFAULT_AI_TASK_NAME = "Google AI Task"
|
DEFAULT_AI_TASK_NAME = "Google AI Task"
|
||||||
|
|
||||||
|
CONF_PROMPT = "prompt"
|
||||||
|
DEFAULT_STT_PROMPT = "Transcribe the attached audio"
|
||||||
|
|
||||||
CONF_RECOMMENDED = "recommended"
|
CONF_RECOMMENDED = "recommended"
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
|
RECOMMENDED_CHAT_MODEL = "models/gemini-2.5-flash"
|
||||||
|
RECOMMENDED_STT_MODEL = RECOMMENDED_CHAT_MODEL
|
||||||
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
|
RECOMMENDED_TTS_MODEL = "models/gemini-2.5-flash-preview-tts"
|
||||||
CONF_TEMPERATURE = "temperature"
|
CONF_TEMPERATURE = "temperature"
|
||||||
RECOMMENDED_TEMPERATURE = 1.0
|
RECOMMENDED_TEMPERATURE = 1.0
|
||||||
@ -43,6 +48,11 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
|
|||||||
CONF_RECOMMENDED: True,
|
CONF_RECOMMENDED: True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RECOMMENDED_STT_OPTIONS = {
|
||||||
|
CONF_PROMPT: DEFAULT_STT_PROMPT,
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
}
|
||||||
|
|
||||||
RECOMMENDED_TTS_OPTIONS = {
|
RECOMMENDED_TTS_OPTIONS = {
|
||||||
CONF_RECOMMENDED: True,
|
CONF_RECOMMENDED: True,
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,38 @@
|
|||||||
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
|
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"stt": {
|
||||||
|
"initiate_flow": {
|
||||||
|
"user": "Add Speech-to-Text service",
|
||||||
|
"reconfigure": "Reconfigure Speech-to-Text service"
|
||||||
|
},
|
||||||
|
"entry_type": "Speech-to-Text",
|
||||||
|
"step": {
|
||||||
|
"set_options": {
|
||||||
|
"data": {
|
||||||
|
"name": "[%key:common::config_flow::data::name%]",
|
||||||
|
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
|
||||||
|
"prompt": "Instructions",
|
||||||
|
"chat_model": "[%key:common::generic::model%]",
|
||||||
|
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
|
||||||
|
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
|
||||||
|
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
|
||||||
|
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
|
||||||
|
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
|
||||||
|
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
|
||||||
|
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
|
||||||
|
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"prompt": "Instruct how the LLM should transcribe the audio."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||||
|
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
"tts": {
|
"tts": {
|
||||||
"initiate_flow": {
|
"initiate_flow": {
|
||||||
"user": "Add Text-to-Speech service",
|
"user": "Add Text-to-Speech service",
|
||||||
|
@ -0,0 +1,254 @@
|
|||||||
|
"""Speech to text support for Google Generative AI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
|
||||||
|
from google.genai.errors import APIError, ClientError
|
||||||
|
from google.genai.types import Part
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_PROMPT,
|
||||||
|
DEFAULT_STT_PROMPT,
|
||||||
|
LOGGER,
|
||||||
|
RECOMMENDED_STT_MODEL,
|
||||||
|
)
|
||||||
|
from .entity import GoogleGenerativeAILLMBaseEntity
|
||||||
|
from .helpers import convert_to_wav
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Create one STT entity for every "stt" subentry of the config entry.

    Subentries of any other type are ignored here. Each entity is registered
    separately so it is tied to the id of the subentry it was built from.
    """
    for stt_subentry in config_entry.subentries.values():
        if stt_subentry.subentry_type == "stt":
            entity = GoogleGenerativeAISttEntity(config_entry, stt_subentry)
            async_add_entities(
                [entity],
                config_subentry_id=stt_subentry.subentry_id,
            )
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleGenerativeAISttEntity(
    stt.SpeechToTextEntity, GoogleGenerativeAILLMBaseEntity
):
    """Google Generative AI speech-to-text entity.

    The whole audio stream is buffered, sent to the model configured on the
    subentry together with the transcription prompt, and the text of the
    model response is returned as the transcript.
    """

    def __init__(self, config_entry: ConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the STT entity."""
        # NOTE(review): the third argument is presumably the default model the
        # base entity falls back to - confirm against GoogleGenerativeAILLMBaseEntity.
        super().__init__(config_entry, subentry, RECOMMENDED_STT_MODEL)

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return [
            "af-ZA",
            "sq-AL",
            "am-ET",
            "ar-DZ",
            "ar-BH",
            "ar-EG",
            "ar-IQ",
            "ar-IL",
            "ar-JO",
            "ar-KW",
            "ar-LB",
            "ar-MA",
            "ar-OM",
            "ar-QA",
            "ar-SA",
            "ar-PS",
            "ar-TN",
            "ar-AE",
            "ar-YE",
            "hy-AM",
            "az-AZ",
            "eu-ES",
            "bn-BD",
            "bn-IN",
            "bs-BA",
            "bg-BG",
            "my-MM",
            "ca-ES",
            "zh-CN",
            "zh-TW",
            "hr-HR",
            "cs-CZ",
            "da-DK",
            "nl-BE",
            "nl-NL",
            "en-AU",
            "en-CA",
            "en-GH",
            "en-HK",
            "en-IN",
            "en-IE",
            "en-KE",
            "en-NZ",
            "en-NG",
            "en-PK",
            "en-PH",
            "en-SG",
            "en-ZA",
            "en-TZ",
            "en-GB",
            "en-US",
            "et-EE",
            "fil-PH",
            "fi-FI",
            "fr-BE",
            "fr-CA",
            "fr-FR",
            "fr-CH",
            "gl-ES",
            "ka-GE",
            "de-AT",
            "de-DE",
            "de-CH",
            "el-GR",
            "gu-IN",
            "iw-IL",
            "hi-IN",
            "hu-HU",
            "is-IS",
            "id-ID",
            "it-IT",
            "it-CH",
            "ja-JP",
            "jv-ID",
            "kn-IN",
            "kk-KZ",
            "km-KH",
            "ko-KR",
            "lo-LA",
            "lv-LV",
            "lt-LT",
            "mk-MK",
            "ms-MY",
            "ml-IN",
            "mr-IN",
            "mn-MN",
            "ne-NP",
            "no-NO",
            "fa-IR",
            "pl-PL",
            "pt-BR",
            "pt-PT",
            "ro-RO",
            "ru-RU",
            "sr-RS",
            "si-LK",
            "sk-SK",
            "sl-SI",
            "es-AR",
            "es-BO",
            "es-CL",
            "es-CO",
            "es-CR",
            "es-DO",
            "es-EC",
            "es-SV",
            "es-GT",
            "es-HN",
            "es-MX",
            "es-NI",
            "es-PA",
            "es-PY",
            "es-PE",
            "es-PR",
            "es-ES",
            "es-US",
            "es-UY",
            "es-VE",
            "su-ID",
            "sw-KE",
            "sw-TZ",
            "sv-SE",
            "ta-IN",
            "ta-MY",
            "ta-SG",
            "ta-LK",
            "te-IN",
            "th-TH",
            "tr-TR",
            "uk-UA",
            "ur-IN",
            "ur-PK",
            "uz-UZ",
            "vi-VN",
            "zu-ZA",
        ]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return a list of supported formats."""
        # https://ai.google.dev/gemini-api/docs/audio#supported-formats
        return [stt.AudioFormats.WAV, stt.AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return a list of supported codecs."""
        return [stt.AudioCodecs.PCM, stt.AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return a list of supported bit rates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return a list of supported sample rates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return a list of supported channels."""
        # Per https://ai.google.dev/gemini-api/docs/audio
        # If the audio source contains multiple channels, Gemini combines those channels into a single channel.
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream to STT service.

        Args:
            metadata: Describes the incoming audio; its format, bit rate and
                sample rate are read here.
            stream: Asynchronous iterable yielding the audio as byte chunks.

        Returns:
            A SUCCESS result carrying the response text when the model
            returned non-empty text. An ERROR result with no text when the
            request raised APIError, ClientError or ValueError (the error is
            logged), or when the response text was empty.
        """
        # Collect the chunks and join them once: `bytes +=` in a loop copies
        # the whole buffer on every chunk, which is quadratic for long audio.
        audio_data = b"".join([chunk async for chunk in stream])
        if metadata.format == stt.AudioFormats.WAV:
            # NOTE(review): presumably the stream is raw PCM without a WAV
            # header and convert_to_wav adds it - confirm against the helper.
            audio_data = convert_to_wav(
                audio_data,
                f"audio/L{metadata.bit_rate.value};rate={metadata.sample_rate.value}",
            )

        try:
            response = await self._genai_client.aio.models.generate_content(
                model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_STT_MODEL),
                contents=[
                    self.subentry.data.get(CONF_PROMPT, DEFAULT_STT_PROMPT),
                    Part.from_bytes(
                        data=audio_data,
                        mime_type=f"audio/{metadata.format.value}",
                    ),
                ],
                config=self.create_generate_content_config(),
            )
        except (APIError, ClientError, ValueError) as err:
            LOGGER.error("Error during STT: %s", err)
        else:
            if response.text:
                return stt.SpeechResult(
                    response.text,
                    stt.SpeechResultState.SUCCESS,
                )

        # Reached after a logged request error or an empty response text.
        return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
|
@ -9,6 +9,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
|||||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DEFAULT_AI_TASK_NAME,
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
DEFAULT_STT_NAME,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@ -39,6 +40,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||||||
"subentry_id": "ulid-conversation",
|
"subentry_id": "ulid-conversation",
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"data": {},
|
||||||
|
"subentry_type": "stt",
|
||||||
|
"title": DEFAULT_STT_NAME,
|
||||||
|
"subentry_id": "ulid-stt",
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {},
|
"data": {},
|
||||||
"subentry_type": "tts",
|
"subentry_type": "tts",
|
||||||
|
@ -34,6 +34,14 @@
|
|||||||
'title': 'Google AI Conversation',
|
'title': 'Google AI Conversation',
|
||||||
'unique_id': None,
|
'unique_id': None,
|
||||||
}),
|
}),
|
||||||
|
'ulid-stt': dict({
|
||||||
|
'data': dict({
|
||||||
|
}),
|
||||||
|
'subentry_id': 'ulid-stt',
|
||||||
|
'subentry_type': 'stt',
|
||||||
|
'title': 'Google AI STT',
|
||||||
|
'unique_id': None,
|
||||||
|
}),
|
||||||
'ulid-tts': dict({
|
'ulid-tts': dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
}),
|
}),
|
||||||
|
@ -32,6 +32,37 @@
|
|||||||
'sw_version': None,
|
'sw_version': None,
|
||||||
'via_device_id': None,
|
'via_device_id': None,
|
||||||
}),
|
}),
|
||||||
|
DeviceRegistryEntrySnapshot({
|
||||||
|
'area_id': None,
|
||||||
|
'config_entries': <ANY>,
|
||||||
|
'config_entries_subentries': <ANY>,
|
||||||
|
'configuration_url': None,
|
||||||
|
'connections': set({
|
||||||
|
}),
|
||||||
|
'disabled_by': None,
|
||||||
|
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
|
||||||
|
'hw_version': None,
|
||||||
|
'id': <ANY>,
|
||||||
|
'identifiers': set({
|
||||||
|
tuple(
|
||||||
|
'google_generative_ai_conversation',
|
||||||
|
'ulid-stt',
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
'is_new': False,
|
||||||
|
'labels': set({
|
||||||
|
}),
|
||||||
|
'manufacturer': 'Google',
|
||||||
|
'model': 'gemini-2.5-flash',
|
||||||
|
'model_id': None,
|
||||||
|
'name': 'Google AI STT',
|
||||||
|
'name_by_user': None,
|
||||||
|
'primary_config_entry': <ANY>,
|
||||||
|
'serial_number': None,
|
||||||
|
'suggested_area': None,
|
||||||
|
'sw_version': None,
|
||||||
|
'via_device_id': None,
|
||||||
|
}),
|
||||||
DeviceRegistryEntrySnapshot({
|
DeviceRegistryEntrySnapshot({
|
||||||
'area_id': None,
|
'area_id': None,
|
||||||
'config_entries': <ANY>,
|
'config_entries': <ANY>,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test the Google Generative AI Conversation config flow."""
|
"""Test the Google Generative AI Conversation config flow."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -21,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
|||||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DEFAULT_AI_TASK_NAME,
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
DEFAULT_STT_NAME,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_AI_TASK_OPTIONS,
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
@ -28,8 +30,11 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
|||||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_STT_MODEL,
|
||||||
|
RECOMMENDED_STT_OPTIONS,
|
||||||
RECOMMENDED_TOP_K,
|
RECOMMENDED_TOP_K,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_TTS_MODEL,
|
||||||
RECOMMENDED_TTS_OPTIONS,
|
RECOMMENDED_TTS_OPTIONS,
|
||||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
)
|
)
|
||||||
@ -64,11 +69,17 @@ def get_models_pager():
|
|||||||
)
|
)
|
||||||
model_15_pro.name = "models/gemini-1.5-pro-latest"
|
model_15_pro.name = "models/gemini-1.5-pro-latest"
|
||||||
|
|
||||||
|
model_25_flash_tts = Mock(
|
||||||
|
supported_actions=["generateContent"],
|
||||||
|
)
|
||||||
|
model_25_flash_tts.name = "models/gemini-2.5-flash-preview-tts"
|
||||||
|
|
||||||
async def models_pager():
|
async def models_pager():
|
||||||
yield model_25_flash
|
yield model_25_flash
|
||||||
yield model_20_flash
|
yield model_20_flash
|
||||||
yield model_15_flash
|
yield model_15_flash
|
||||||
yield model_15_pro
|
yield model_15_pro
|
||||||
|
yield model_25_flash_tts
|
||||||
|
|
||||||
return models_pager()
|
return models_pager()
|
||||||
|
|
||||||
@ -129,6 +140,12 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
"title": DEFAULT_AI_TASK_NAME,
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "stt",
|
||||||
|
"data": RECOMMENDED_STT_OPTIONS,
|
||||||
|
"title": DEFAULT_STT_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
@ -157,22 +174,35 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
|||||||
assert result["reason"] == "already_configured"
|
assert result["reason"] == "already_configured"
|
||||||
|
|
||||||
|
|
||||||
async def test_creating_conversation_subentry(
|
@pytest.mark.parametrize(
|
||||||
|
("subentry_type", "options"),
|
||||||
|
[
|
||||||
|
("conversation", RECOMMENDED_CONVERSATION_OPTIONS),
|
||||||
|
("stt", RECOMMENDED_STT_OPTIONS),
|
||||||
|
("tts", RECOMMENDED_TTS_OPTIONS),
|
||||||
|
("ai_task_data", RECOMMENDED_AI_TASK_OPTIONS),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_creating_subentry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_init_component: None,
|
mock_init_component: None,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
|
subentry_type: str,
|
||||||
|
options: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating a conversation subentry."""
|
"""Test creating a subentry."""
|
||||||
|
old_subentries = set(mock_config_entry.subentries)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"google.genai.models.AsyncModels.list",
|
"google.genai.models.AsyncModels.list",
|
||||||
return_value=get_models_pager(),
|
return_value=get_models_pager(),
|
||||||
):
|
):
|
||||||
result = await hass.config_entries.subentries.async_init(
|
result = await hass.config_entries.subentries.async_init(
|
||||||
(mock_config_entry.entry_id, "conversation"),
|
(mock_config_entry.entry_id, subentry_type),
|
||||||
context={"source": config_entries.SOURCE_USER},
|
context={"source": config_entries.SOURCE_USER},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["type"] is FlowResultType.FORM
|
assert result["type"] is FlowResultType.FORM, result
|
||||||
assert result["step_id"] == "set_options"
|
assert result["step_id"] == "set_options"
|
||||||
assert not result["errors"]
|
assert not result["errors"]
|
||||||
|
|
||||||
@ -182,31 +212,117 @@ async def test_creating_conversation_subentry(
|
|||||||
):
|
):
|
||||||
result2 = await hass.config_entries.subentries.async_configure(
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
|
result["data_schema"]({CONF_NAME: "Mock name", **options}),
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
expected_options = options.copy()
|
||||||
|
if CONF_PROMPT in expected_options:
|
||||||
|
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
|
||||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result2["title"] == "Mock name"
|
assert result2["title"] == "Mock name"
|
||||||
|
assert result2["data"] == expected_options
|
||||||
|
|
||||||
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
|
||||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
|
||||||
|
|
||||||
assert result2["data"] == processed_options
|
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||||
|
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||||
|
|
||||||
|
assert new_subentry.subentry_type == subentry_type
|
||||||
|
assert new_subentry.data == expected_options
|
||||||
|
assert new_subentry.title == "Mock name"
|
||||||
|
|
||||||
|
|
||||||
async def test_creating_tts_subentry(
|
@pytest.mark.parametrize(
|
||||||
|
("subentry_type", "recommended_model", "options"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"conversation",
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
{
|
||||||
|
CONF_PROMPT: "You are Mario",
|
||||||
|
CONF_LLM_HASS_API: ["assist"],
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_TOP_P: 1.0,
|
||||||
|
CONF_TOP_K: 1,
|
||||||
|
CONF_MAX_TOKENS: 1024,
|
||||||
|
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"stt",
|
||||||
|
RECOMMENDED_STT_MODEL,
|
||||||
|
{
|
||||||
|
CONF_PROMPT: "Transcribe this",
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_STT_MODEL,
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_TOP_P: 1.0,
|
||||||
|
CONF_TOP_K: 1,
|
||||||
|
CONF_MAX_TOKENS: 1024,
|
||||||
|
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"tts",
|
||||||
|
RECOMMENDED_TTS_MODEL,
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_TTS_MODEL,
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_TOP_P: 1.0,
|
||||||
|
CONF_TOP_K: 1,
|
||||||
|
CONF_MAX_TOKENS: 1024,
|
||||||
|
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ai_task_data",
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_TOP_P: 1.0,
|
||||||
|
CONF_TOP_K: 1,
|
||||||
|
CONF_MAX_TOKENS: 1024,
|
||||||
|
CONF_HARASSMENT_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_HATE_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_SEXUAL_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
CONF_DANGEROUS_BLOCK_THRESHOLD: "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_creating_subentry_custom_options(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_init_component: None,
|
mock_init_component: None,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
|
subentry_type: str,
|
||||||
|
recommended_model: str,
|
||||||
|
options: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating a TTS subentry."""
|
"""Test creating a subentry with custom options."""
|
||||||
|
old_subentries = set(mock_config_entry.subentries)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"google.genai.models.AsyncModels.list",
|
"google.genai.models.AsyncModels.list",
|
||||||
return_value=get_models_pager(),
|
return_value=get_models_pager(),
|
||||||
):
|
):
|
||||||
result = await hass.config_entries.subentries.async_init(
|
result = await hass.config_entries.subentries.async_init(
|
||||||
(mock_config_entry.entry_id, "tts"),
|
(mock_config_entry.entry_id, subentry_type),
|
||||||
context={"source": config_entries.SOURCE_USER},
|
context={"source": config_entries.SOURCE_USER},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -214,75 +330,52 @@ async def test_creating_tts_subentry(
|
|||||||
assert result["step_id"] == "set_options"
|
assert result["step_id"] == "set_options"
|
||||||
assert not result["errors"]
|
assert not result["errors"]
|
||||||
|
|
||||||
old_subentries = set(mock_config_entry.subentries)
|
# Uncheck recommended to show custom options
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"google.genai.models.AsyncModels.list",
|
"google.genai.models.AsyncModels.list",
|
||||||
return_value=get_models_pager(),
|
return_value=get_models_pager(),
|
||||||
):
|
):
|
||||||
result2 = await hass.config_entries.subentries.async_configure(
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{CONF_NAME: "Mock TTS", **RECOMMENDED_TTS_OPTIONS},
|
result["data_schema"]({CONF_RECOMMENDED: False}),
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
assert result2["type"] is FlowResultType.FORM
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
# Find the schema key for CONF_CHAT_MODEL and check its default
|
||||||
assert result2["title"] == "Mock TTS"
|
schema_dict = result2["data_schema"].schema
|
||||||
assert result2["data"] == RECOMMENDED_TTS_OPTIONS
|
chat_model_key = next(key for key in schema_dict if key.schema == CONF_CHAT_MODEL)
|
||||||
|
assert chat_model_key.default() == recommended_model
|
||||||
|
models_in_selector = [
|
||||||
|
opt["value"] for opt in schema_dict[chat_model_key].config["options"]
|
||||||
|
]
|
||||||
|
assert recommended_model in models_in_selector
|
||||||
|
|
||||||
assert len(mock_config_entry.subentries) == 4
|
# Submit the form
|
||||||
|
|
||||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
|
||||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
|
||||||
|
|
||||||
assert new_subentry.subentry_type == "tts"
|
|
||||||
assert new_subentry.data == RECOMMENDED_TTS_OPTIONS
|
|
||||||
assert new_subentry.title == "Mock TTS"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_creating_ai_task_subentry(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_init_component: None,
|
|
||||||
mock_config_entry: MockConfigEntry,
|
|
||||||
) -> None:
|
|
||||||
"""Test creating an AI task subentry."""
|
|
||||||
with patch(
|
with patch(
|
||||||
"google.genai.models.AsyncModels.list",
|
"google.genai.models.AsyncModels.list",
|
||||||
return_value=get_models_pager(),
|
return_value=get_models_pager(),
|
||||||
):
|
):
|
||||||
result = await hass.config_entries.subentries.async_init(
|
result3 = await hass.config_entries.subentries.async_configure(
|
||||||
(mock_config_entry.entry_id, "ai_task_data"),
|
result2["flow_id"],
|
||||||
context={"source": config_entries.SOURCE_USER},
|
result2["data_schema"]({CONF_NAME: "Mock name", **options}),
|
||||||
)
|
|
||||||
|
|
||||||
assert result["type"] is FlowResultType.FORM, result
|
|
||||||
assert result["step_id"] == "set_options"
|
|
||||||
assert not result["errors"]
|
|
||||||
|
|
||||||
old_subentries = set(mock_config_entry.subentries)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"google.genai.models.AsyncModels.list",
|
|
||||||
return_value=get_models_pager(),
|
|
||||||
):
|
|
||||||
result2 = await hass.config_entries.subentries.async_configure(
|
|
||||||
result["flow_id"],
|
|
||||||
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
|
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
expected_options = options.copy()
|
||||||
assert result2["title"] == "Mock AI Task"
|
if CONF_PROMPT in expected_options:
|
||||||
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
|
expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
|
||||||
|
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result3["title"] == "Mock name"
|
||||||
|
assert result3["data"] == expected_options
|
||||||
|
|
||||||
assert len(mock_config_entry.subentries) == 4
|
assert len(mock_config_entry.subentries) == len(old_subentries) + 1
|
||||||
|
|
||||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||||
|
|
||||||
assert new_subentry.subentry_type == "ai_task_data"
|
assert new_subentry.subentry_type == subentry_type
|
||||||
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
assert new_subentry.data == expected_options
|
||||||
assert new_subentry.title == "Mock AI Task"
|
assert new_subentry.title == "Mock name"
|
||||||
|
|
||||||
|
|
||||||
async def test_creating_conversation_subentry_not_loaded(
|
async def test_creating_conversation_subentry_not_loaded(
|
||||||
|
@ -11,11 +11,13 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||||
DEFAULT_AI_TASK_NAME,
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
DEFAULT_STT_NAME,
|
||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
DEFAULT_TTS_NAME,
|
DEFAULT_TTS_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_AI_TASK_OPTIONS,
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
|
RECOMMENDED_STT_OPTIONS,
|
||||||
RECOMMENDED_TTS_OPTIONS,
|
RECOMMENDED_TTS_OPTIONS,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import (
|
||||||
@ -489,7 +491,7 @@ async def test_migration_from_v1(
|
|||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == DEFAULT_TITLE
|
assert entry.title == DEFAULT_TITLE
|
||||||
assert len(entry.subentries) == 4
|
assert len(entry.subentries) == 5
|
||||||
conversation_subentries = [
|
conversation_subentries = [
|
||||||
subentry
|
subentry
|
||||||
for subentry in entry.subentries.values()
|
for subentry in entry.subentries.values()
|
||||||
@ -516,6 +518,14 @@ async def test_migration_from_v1(
|
|||||||
assert len(ai_task_subentries) == 1
|
assert len(ai_task_subentries) == 1
|
||||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||||
|
stt_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "stt"
|
||||||
|
]
|
||||||
|
assert len(stt_subentries) == 1
|
||||||
|
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||||
|
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||||
|
|
||||||
subentry = conversation_subentries[0]
|
subentry = conversation_subentries[0]
|
||||||
|
|
||||||
@ -721,7 +731,7 @@ async def test_migration_from_v1_disabled(
|
|||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == DEFAULT_TITLE
|
assert entry.title == DEFAULT_TITLE
|
||||||
assert len(entry.subentries) == 4
|
assert len(entry.subentries) == 5
|
||||||
conversation_subentries = [
|
conversation_subentries = [
|
||||||
subentry
|
subentry
|
||||||
for subentry in entry.subentries.values()
|
for subentry in entry.subentries.values()
|
||||||
@ -748,6 +758,14 @@ async def test_migration_from_v1_disabled(
|
|||||||
assert len(ai_task_subentries) == 1
|
assert len(ai_task_subentries) == 1
|
||||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||||
|
stt_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "stt"
|
||||||
|
]
|
||||||
|
assert len(stt_subentries) == 1
|
||||||
|
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||||
|
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||||
|
|
||||||
assert not device_registry.async_get_device(
|
assert not device_registry.async_get_device(
|
||||||
identifiers={(DOMAIN, mock_config_entry.entry_id)}
|
identifiers={(DOMAIN, mock_config_entry.entry_id)}
|
||||||
@ -860,7 +878,7 @@ async def test_migration_from_v1_with_multiple_keys(
|
|||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == DEFAULT_TITLE
|
assert entry.title == DEFAULT_TITLE
|
||||||
assert len(entry.subentries) == 3
|
assert len(entry.subentries) == 4
|
||||||
subentry = list(entry.subentries.values())[0]
|
subentry = list(entry.subentries.values())[0]
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == options
|
assert subentry.data == options
|
||||||
@ -873,6 +891,10 @@ async def test_migration_from_v1_with_multiple_keys(
|
|||||||
assert subentry.subentry_type == "ai_task_data"
|
assert subentry.subentry_type == "ai_task_data"
|
||||||
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
assert subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
assert subentry.title == DEFAULT_AI_TASK_NAME
|
assert subentry.title == DEFAULT_AI_TASK_NAME
|
||||||
|
subentry = list(entry.subentries.values())[3]
|
||||||
|
assert subentry.subentry_type == "stt"
|
||||||
|
assert subentry.data == RECOMMENDED_STT_OPTIONS
|
||||||
|
assert subentry.title == DEFAULT_STT_NAME
|
||||||
|
|
||||||
dev = device_registry.async_get_device(
|
dev = device_registry.async_get_device(
|
||||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||||
@ -963,7 +985,7 @@ async def test_migration_from_v1_with_same_keys(
|
|||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == DEFAULT_TITLE
|
assert entry.title == DEFAULT_TITLE
|
||||||
assert len(entry.subentries) == 4
|
assert len(entry.subentries) == 5
|
||||||
conversation_subentries = [
|
conversation_subentries = [
|
||||||
subentry
|
subentry
|
||||||
for subentry in entry.subentries.values()
|
for subentry in entry.subentries.values()
|
||||||
@ -990,6 +1012,14 @@ async def test_migration_from_v1_with_same_keys(
|
|||||||
assert len(ai_task_subentries) == 1
|
assert len(ai_task_subentries) == 1
|
||||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||||
|
stt_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "stt"
|
||||||
|
]
|
||||||
|
assert len(stt_subentries) == 1
|
||||||
|
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||||
|
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||||
|
|
||||||
subentry = conversation_subentries[0]
|
subentry = conversation_subentries[0]
|
||||||
|
|
||||||
@ -1090,10 +1120,11 @@ async def test_migration_from_v2_1(
|
|||||||
"""Test migration from version 2.1.
|
"""Test migration from version 2.1.
|
||||||
|
|
||||||
This tests we clean up the broken migration in Home Assistant Core
|
This tests we clean up the broken migration in Home Assistant Core
|
||||||
2025.7.0b0-2025.7.0b1 and add AI Task subentry:
|
2025.7.0b0-2025.7.0b1 and add AI Task and STT subentries:
|
||||||
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
|
- Fix device registry (Fixed in Home Assistant Core 2025.7.0b2)
|
||||||
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
|
- Add TTS subentry (Added in Home Assistant Core 2025.7.0b1)
|
||||||
- Add AI Task subentry (Added in version 2.3)
|
- Add AI Task subentry (Added in version 2.3)
|
||||||
|
- Add STT subentry (Added in version 2.3)
|
||||||
"""
|
"""
|
||||||
# Create a v2.1 config entry with 2 subentries, devices and entities
|
# Create a v2.1 config entry with 2 subentries, devices and entities
|
||||||
options = {
|
options = {
|
||||||
@ -1184,7 +1215,7 @@ async def test_migration_from_v2_1(
|
|||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == DEFAULT_TITLE
|
assert entry.title == DEFAULT_TITLE
|
||||||
assert len(entry.subentries) == 4
|
assert len(entry.subentries) == 5
|
||||||
conversation_subentries = [
|
conversation_subentries = [
|
||||||
subentry
|
subentry
|
||||||
for subentry in entry.subentries.values()
|
for subentry in entry.subentries.values()
|
||||||
@ -1211,6 +1242,14 @@ async def test_migration_from_v2_1(
|
|||||||
assert len(ai_task_subentries) == 1
|
assert len(ai_task_subentries) == 1
|
||||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||||
|
stt_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "stt"
|
||||||
|
]
|
||||||
|
assert len(stt_subentries) == 1
|
||||||
|
assert stt_subentries[0].data == RECOMMENDED_STT_OPTIONS
|
||||||
|
assert stt_subentries[0].title == DEFAULT_STT_NAME
|
||||||
|
|
||||||
subentry = conversation_subentries[0]
|
subentry = conversation_subentries[0]
|
||||||
|
|
||||||
@ -1320,8 +1359,8 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
|
|||||||
assert entry.version == 2
|
assert entry.version == 2
|
||||||
assert entry.minor_version == 4
|
assert entry.minor_version == 4
|
||||||
|
|
||||||
# Check we now have conversation, tts and ai_task_data subentries
|
# Check we now have conversation, tts, stt, and ai_task_data subentries
|
||||||
assert len(entry.subentries) == 3
|
assert len(entry.subentries) == 4
|
||||||
|
|
||||||
subentries = {
|
subentries = {
|
||||||
subentry.subentry_type: subentry for subentry in entry.subentries.values()
|
subentry.subentry_type: subentry for subentry in entry.subentries.values()
|
||||||
@ -1336,6 +1375,12 @@ async def test_migrate_entry_from_v2_2(hass: HomeAssistant) -> None:
|
|||||||
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||||
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
|
|
||||||
|
# Find and verify the stt subentry
|
||||||
|
ai_task_subentry = subentries["stt"]
|
||||||
|
assert ai_task_subentry is not None
|
||||||
|
assert ai_task_subentry.title == DEFAULT_STT_NAME
|
||||||
|
assert ai_task_subentry.data == RECOMMENDED_STT_OPTIONS
|
||||||
|
|
||||||
# Verify conversation subentry is still there and unchanged
|
# Verify conversation subentry is still there and unchanged
|
||||||
conversation_subentry = subentries["conversation"]
|
conversation_subentry = subentries["conversation"]
|
||||||
assert conversation_subentry is not None
|
assert conversation_subentry is not None
|
||||||
|
303
tests/components/google_generative_ai_conversation/test_stt.py
Normal file
303
tests/components/google_generative_ai_conversation/test_stt.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
"""Tests for the Google Generative AI Conversation STT entity."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterable, Generator
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
from google.genai import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||||
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_PROMPT,
|
||||||
|
DEFAULT_STT_PROMPT,
|
||||||
|
DOMAIN,
|
||||||
|
RECOMMENDED_STT_MODEL,
|
||||||
|
)
|
||||||
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
|
from homeassistant.const import CONF_API_KEY
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
TEST_CHAT_MODEL = "models/gemini-2.5-flash"
|
||||||
|
TEST_PROMPT = "Please transcribe the audio."
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_get_audio_stream(data: bytes) -> AsyncIterable[bytes]:
    """Produce an async audio stream that emits ``data`` as its only chunk."""
    for chunk in (data,):
        yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_genai_client() -> Generator[AsyncMock]:
    """Patch the integration's ``Client`` and yield the fake client object.

    The fake exposes awaitable ``aio.models.get`` and
    ``aio.models.generate_content`` mocks; the latter answers with a single
    candidate whose text is "This is a test transcription.".
    """
    canned_response = types.GenerateContentResponse(
        candidates=[
            {
                "content": {
                    "parts": [{"text": "This is a test transcription."}],
                    "role": "model",
                }
            }
        ]
    )

    fake_client = Mock()
    fake_client.aio.models.get = AsyncMock()
    fake_client.aio.models.generate_content = AsyncMock(
        return_value=canned_response
    )

    client_patcher = patch(
        "homeassistant.components.google_generative_ai_conversation.Client",
        return_value=fake_client,
    )
    with client_patcher as patched_client_cls:
        # Hand out whatever the patched constructor returns (the fake above).
        yield patched_client_cls.return_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def setup_integration(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Register a config entry with one STT subentry and set it up.

    The subentry is configured with ``TEST_CHAT_MODEL`` and ``TEST_PROMPT`` and
    is titled "Google AI STT".
    """
    entry = MockConfigEntry(
        domain=DOMAIN,
        data={CONF_API_KEY: "bla"},
        version=2,
        minor_version=1,
    )
    entry.add_to_hass(hass)

    stt_options = {
        CONF_CHAT_MODEL: TEST_CHAT_MODEL,
        CONF_PROMPT: TEST_PROMPT,
    }
    stt_subentry = ConfigSubentry(
        data=stt_options,
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )

    # Attach the mocked client to the entry before the subentry is added.
    entry.runtime_data = mock_genai_client
    hass.config_entries.async_add_subentry(entry, stt_subentry)

    await hass.config_entries.async_setup(entry.entry_id)
    await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_integration")
async def test_stt_entity_properties(hass: HomeAssistant) -> None:
    """Check the capabilities advertised by the STT entity."""
    stt_component = hass.data[stt.DOMAIN]
    entity: stt.SpeechToTextEntity = stt_component.get_entity("stt.google_ai_stt")
    assert entity is not None

    # Languages are only checked for type, not for specific members.
    assert isinstance(entity.supported_languages, list)

    formats = entity.supported_formats
    assert stt.AudioFormats.WAV in formats
    assert stt.AudioFormats.OGG in formats

    codecs = entity.supported_codecs
    assert stt.AudioCodecs.PCM in codecs
    assert stt.AudioCodecs.OPUS in codecs

    assert stt.AudioBitRates.BITRATE_16 in entity.supported_bit_rates
    assert stt.AudioSampleRates.SAMPLERATE_16000 in entity.supported_sample_rates
    assert stt.AudioChannels.CHANNEL_MONO in entity.supported_channels
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("audio_format", "call_convert_to_wav"),
    [
        (stt.AudioFormats.WAV, True),
        (stt.AudioFormats.OGG, False),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_success(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    audio_format: stt.AudioFormats,
    call_convert_to_wav: bool,
) -> None:
    """Verify a successful transcription for WAV and OGG input.

    WAV input is expected to pass through ``convert_to_wav`` before upload,
    while OGG input is expected to be forwarded untouched.
    """
    raw_audio = b"test_audio_bytes"
    converted_audio = b"converted_wav_bytes"

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=audio_format,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    stream = _async_get_audio_stream(raw_audio)

    with patch(
        "homeassistant.components.google_generative_ai_conversation.stt.convert_to_wav",
        return_value=converted_audio,
    ) as convert_mock:
        outcome = await entity.async_process_audio_stream(speech_metadata, stream)

    assert outcome.result == stt.SpeechResultState.SUCCESS
    assert outcome.text == "This is a test transcription."

    # Conversion must happen exactly once for WAV and never for OGG.
    if not call_convert_to_wav:
        convert_mock.assert_not_called()
    else:
        convert_mock.assert_called_once_with(raw_audio, "audio/L16;rate=16000")

    generate_content = mock_genai_client.aio.models.generate_content
    generate_content.assert_called_once()
    request_kwargs = generate_content.call_args.kwargs
    assert request_kwargs["model"] == TEST_CHAT_MODEL

    # The request carries the prompt first, then the audio as an inline part.
    prompt, audio_part = request_kwargs["contents"][0], request_kwargs["contents"][1]
    assert prompt == TEST_PROMPT
    assert isinstance(audio_part, types.Part)
    assert audio_part.inline_data.mime_type == f"audio/{audio_format.value}"

    expected_payload = converted_audio if call_convert_to_wav else raw_audio
    assert audio_part.inline_data.data == expected_payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "side_effect",
    [
        API_ERROR_500,
        CLIENT_ERROR_BAD_REQUEST,
        ValueError("Test value error"),
    ],
)
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_api_error(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
    side_effect: Exception,
) -> None:
    """Verify that a failing ``generate_content`` call yields an ERROR result."""
    # Make the mocked API raise the parametrized exception.
    generate_content = mock_genai_client.aio.models.generate_content
    generate_content.side_effect = side_effect

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )

    outcome = await entity.async_process_audio_stream(
        speech_metadata, _async_get_audio_stream(b"test_audio_bytes")
    )

    assert outcome.result == stt.SpeechResultState.ERROR
    assert outcome.text is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_integration")
async def test_stt_process_audio_stream_empty_response(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Verify that a response without candidates yields an ERROR result."""
    # Replace the canned transcription with a response holding no candidates.
    no_candidates = types.GenerateContentResponse(candidates=[])
    mock_genai_client.aio.models.generate_content.return_value = no_candidates

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")
    speech_metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )

    outcome = await entity.async_process_audio_stream(
        speech_metadata, _async_get_audio_stream(b"test_audio_bytes")
    )

    assert outcome.result == stt.SpeechResultState.ERROR
    assert outcome.text is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_uses_default_prompt(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default prompt is used if none is configured.

    The STT subentry is created with a model but without ``CONF_PROMPT``; the
    request sent to ``generate_content`` must then start with
    ``DEFAULT_STT_PROMPT``.
    """
    # NOTE: ``mock_genai_client`` is requested as an argument, which already
    # activates the fixture; no separate ``usefixtures`` marker is needed.
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no prompt
    sub_entry = ConfigSubentry(
        data={CONF_CHAT_MODEL: TEST_CHAT_MODEL},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    # Fail with a clear assertion if the API was never called; otherwise
    # ``call_args`` would be ``None`` and ``.kwargs`` would raise
    # AttributeError, hiding the real cause.
    generate_content = mock_genai_client.aio.models.generate_content
    generate_content.assert_called_once()

    contents = generate_content.call_args.kwargs["contents"]
    assert contents[0] == DEFAULT_STT_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_uses_default_model(
    hass: HomeAssistant,
    mock_genai_client: AsyncMock,
) -> None:
    """Test that the default model is used if none is configured.

    The STT subentry is created with a prompt but without ``CONF_CHAT_MODEL``;
    the request sent to ``generate_content`` must then target
    ``RECOMMENDED_STT_MODEL``.
    """
    # NOTE: ``mock_genai_client`` is requested as an argument, which already
    # activates the fixture; no separate ``usefixtures`` marker is needed.
    config_entry = MockConfigEntry(
        domain=DOMAIN, data={CONF_API_KEY: "bla"}, version=2, minor_version=1
    )
    config_entry.add_to_hass(hass)
    config_entry.runtime_data = mock_genai_client

    # Subentry with no model
    sub_entry = ConfigSubentry(
        data={CONF_PROMPT: TEST_PROMPT},
        subentry_type="stt",
        title="Google AI STT",
        unique_id=None,
    )
    hass.config_entries.async_add_subentry(config_entry, sub_entry)
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    entity = hass.data[stt.DOMAIN].get_entity("stt.google_ai_stt")

    metadata = stt.SpeechMetadata(
        language="en-US",
        format=stt.AudioFormats.OGG,
        codec=stt.AudioCodecs.OPUS,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = _async_get_audio_stream(b"test_audio_bytes")

    await entity.async_process_audio_stream(metadata, audio_stream)

    # Fail with a clear assertion if the API was never called; otherwise
    # ``call_args`` would be ``None`` and ``.kwargs`` would raise
    # AttributeError, hiding the real cause.
    generate_content = mock_genai_client.aio.models.generate_content
    generate_content.assert_called_once()

    assert generate_content.call_args.kwargs["model"] == RECOMMENDED_STT_MODEL
|
Loading…
x
Reference in New Issue
Block a user