Refactor the validation in Google Cloud TTS (#120853)

This commit is contained in:
tronikos 2024-07-06 02:44:46 -07:00 committed by GitHub
parent 8ff4991f07
commit 17daccd38a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 170 additions and 132 deletions

View File

@ -0,0 +1,16 @@
"""Constants for the Google Cloud component."""
from __future__ import annotations
CONF_KEY_FILE = "key_file"
DEFAULT_LANG = "en-US"
CONF_GENDER = "gender"
CONF_VOICE = "voice"
CONF_ENCODING = "encoding"
CONF_SPEED = "speed"
CONF_PITCH = "pitch"
CONF_GAIN = "gain"
CONF_PROFILES = "profiles"
CONF_TEXT_TYPE = "text_type"

View File

@ -2,7 +2,36 @@
from __future__ import annotations
from types import MappingProxyType
from typing import Any
from google.cloud import texttospeech
import voluptuous as vol
from homeassistant.components.tts import CONF_LANG
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
)
from .const import (
CONF_ENCODING,
CONF_GAIN,
CONF_GENDER,
CONF_KEY_FILE,
CONF_PITCH,
CONF_PROFILES,
CONF_SPEED,
CONF_TEXT_TYPE,
CONF_VOICE,
DEFAULT_LANG,
)
DEFAULT_VOICE = ""
async def async_tts_voices(
@ -17,3 +46,106 @@ async def async_tts_voices(
voices[language_code] = []
voices[language_code].append(voice.name)
return voices
def tts_options_schema(
config_options: MappingProxyType[str, Any], voices: dict[str, list[str]]
):
"""Return schema for TTS options with default values from config or constants."""
return vol.Schema(
{
vol.Optional(
CONF_GENDER,
description={"suggested_value": config_options.get(CONF_GENDER)},
default=texttospeech.SsmlVoiceGender.NEUTRAL.name, # type: ignore[attr-defined]
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=list(texttospeech.SsmlVoiceGender.__members__),
)
),
vol.Optional(
CONF_VOICE,
description={"suggested_value": config_options.get(CONF_VOICE)},
default=DEFAULT_VOICE,
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=["", *sum(voices.values(), [])],
)
),
vol.Optional(
CONF_ENCODING,
description={"suggested_value": config_options.get(CONF_ENCODING)},
default=texttospeech.AudioEncoding.MP3.name, # type: ignore[attr-defined]
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=list(texttospeech.AudioEncoding.__members__),
)
),
vol.Optional(
CONF_SPEED,
description={"suggested_value": config_options.get(CONF_SPEED)},
default=1.0,
): NumberSelector(NumberSelectorConfig(min=0.25, max=4.0, step=0.01)),
vol.Optional(
CONF_PITCH,
description={"suggested_value": config_options.get(CONF_PITCH)},
default=0,
): NumberSelector(NumberSelectorConfig(min=-20.0, max=20.0, step=0.1)),
vol.Optional(
CONF_GAIN,
description={"suggested_value": config_options.get(CONF_GAIN)},
default=0,
): NumberSelector(NumberSelectorConfig(min=-96.0, max=16.0, step=0.1)),
vol.Optional(
CONF_PROFILES,
description={"suggested_value": config_options.get(CONF_PROFILES)},
default=[],
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=[
# https://cloud.google.com/text-to-speech/docs/audio-profiles
"wearable-class-device",
"handset-class-device",
"headphone-class-device",
"small-bluetooth-speaker-class-device",
"medium-bluetooth-speaker-class-device",
"large-home-entertainment-class-device",
"large-automotive-class-device",
"telephony-class-application",
],
multiple=True,
sort=False,
)
),
vol.Optional(
CONF_TEXT_TYPE,
description={"suggested_value": config_options.get(CONF_TEXT_TYPE)},
default="text",
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=["text", "ssml"],
)
),
}
)
def tts_platform_schema():
"""Return schema for TTS platform."""
return vol.Schema(
{
vol.Optional(CONF_KEY_FILE): cv.string,
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex(
r"[a-z]{2,3}-[A-Z]{2}|"
),
**tts_options_schema({}, {}).schema,
vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): cv.matches_regex(
r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|"
),
}
)

View File

@ -14,92 +14,24 @@ from homeassistant.components.tts import (
Voice,
)
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from .helpers import async_tts_voices
from .const import (
CONF_ENCODING,
CONF_GAIN,
CONF_GENDER,
CONF_KEY_FILE,
CONF_PITCH,
CONF_PROFILES,
CONF_SPEED,
CONF_TEXT_TYPE,
CONF_VOICE,
DEFAULT_LANG,
)
from .helpers import async_tts_voices, tts_options_schema, tts_platform_schema
_LOGGER = logging.getLogger(__name__)
CONF_KEY_FILE = "key_file"
CONF_GENDER = "gender"
CONF_VOICE = "voice"
CONF_ENCODING = "encoding"
CONF_SPEED = "speed"
CONF_PITCH = "pitch"
CONF_GAIN = "gain"
CONF_PROFILES = "profiles"
CONF_TEXT_TYPE = "text_type"
DEFAULT_LANG = "en-US"
DEFAULT_GENDER = "NEUTRAL"
LANG_REGEX = r"[a-z]{2,3}-[A-Z]{2}|"
VOICE_REGEX = r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|"
DEFAULT_VOICE = ""
DEFAULT_ENCODING = "MP3"
MIN_SPEED = 0.25
MAX_SPEED = 4.0
DEFAULT_SPEED = 1.0
MIN_PITCH = -20.0
MAX_PITCH = 20.0
DEFAULT_PITCH = 0
MIN_GAIN = -96.0
MAX_GAIN = 16.0
DEFAULT_GAIN = 0
SUPPORTED_TEXT_TYPES = ["text", "ssml"]
DEFAULT_TEXT_TYPE = "text"
SUPPORTED_PROFILES = [
"wearable-class-device",
"handset-class-device",
"headphone-class-device",
"small-bluetooth-speaker-class-device",
"medium-bluetooth-speaker-class-device",
"large-home-entertainment-class-device",
"large-automotive-class-device",
"telephony-class-application",
]
SUPPORTED_OPTIONS = [
CONF_VOICE,
CONF_GENDER,
CONF_ENCODING,
CONF_SPEED,
CONF_PITCH,
CONF_GAIN,
CONF_PROFILES,
CONF_TEXT_TYPE,
]
GENDER_SCHEMA = vol.All(vol.Upper, vol.In(texttospeech.SsmlVoiceGender.__members__))
VOICE_SCHEMA = cv.matches_regex(VOICE_REGEX)
SCHEMA_ENCODING = vol.All(vol.Upper, vol.In(texttospeech.AudioEncoding.__members__))
SPEED_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_SPEED, max=MAX_SPEED))
PITCH_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_PITCH, max=MAX_PITCH))
GAIN_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_GAIN, max=MAX_GAIN))
PROFILES_SCHEMA = vol.All(cv.ensure_list, [vol.In(SUPPORTED_PROFILES)])
TEXT_TYPE_SCHEMA = vol.All(vol.Lower, vol.In(SUPPORTED_TEXT_TYPES))
PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_KEY_FILE): cv.string,
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex(LANG_REGEX),
vol.Optional(CONF_GENDER, default=DEFAULT_GENDER): GENDER_SCHEMA,
vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): VOICE_SCHEMA,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): SCHEMA_ENCODING,
vol.Optional(CONF_SPEED, default=DEFAULT_SPEED): SPEED_SCHEMA,
vol.Optional(CONF_PITCH, default=DEFAULT_PITCH): PITCH_SCHEMA,
vol.Optional(CONF_GAIN, default=DEFAULT_GAIN): GAIN_SCHEMA,
vol.Optional(CONF_PROFILES, default=[]): PROFILES_SCHEMA,
vol.Optional(CONF_TEXT_TYPE, default=DEFAULT_TEXT_TYPE): TEXT_TYPE_SCHEMA,
}
)
PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(tts_platform_schema().schema)
async def async_get_engine(hass, config, discovery_info=None):
@ -124,15 +56,8 @@ async def async_get_engine(hass, config, discovery_info=None):
hass,
client,
voices,
config[CONF_LANG],
config[CONF_GENDER],
config[CONF_VOICE],
config[CONF_ENCODING],
config[CONF_SPEED],
config[CONF_PITCH],
config[CONF_GAIN],
config[CONF_PROFILES],
config[CONF_TEXT_TYPE],
config.get(CONF_LANG, DEFAULT_LANG),
tts_options_schema(config, voices),
)
@ -144,15 +69,8 @@ class GoogleCloudTTSProvider(Provider):
hass: HomeAssistant,
client: texttospeech.TextToSpeechAsyncClient,
voices: dict[str, list[str]],
language=DEFAULT_LANG,
gender=DEFAULT_GENDER,
voice=DEFAULT_VOICE,
encoding=DEFAULT_ENCODING,
speed=1.0,
pitch=0,
gain=0,
profiles=None,
text_type=DEFAULT_TEXT_TYPE,
language,
options_schema,
) -> None:
"""Init Google Cloud TTS service."""
self.hass = hass
@ -160,14 +78,7 @@ class GoogleCloudTTSProvider(Provider):
self._client = client
self._voices = voices
self._language = language
self._gender = gender
self._voice = voice
self._encoding = encoding
self._speed = speed
self._pitch = pitch
self._gain = gain
self._profiles = profiles
self._text_type = text_type
self._options_schema = options_schema
@property
def supported_languages(self):
@ -182,21 +93,12 @@ class GoogleCloudTTSProvider(Provider):
@property
def supported_options(self):
"""Return a list of supported options."""
return SUPPORTED_OPTIONS
return [option.schema for option in self._options_schema.schema]
@property
def default_options(self):
"""Return a dict including default options."""
return {
CONF_GENDER: self._gender,
CONF_VOICE: self._voice,
CONF_ENCODING: self._encoding,
CONF_SPEED: self._speed,
CONF_PITCH: self._pitch,
CONF_GAIN: self._gain,
CONF_PROFILES: self._profiles,
CONF_TEXT_TYPE: self._text_type,
}
return self._options_schema({})
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
@ -207,20 +109,8 @@ class GoogleCloudTTSProvider(Provider):
async def async_get_tts_audio(self, message, language, options):
"""Load TTS from google."""
options_schema = vol.Schema(
{
vol.Optional(CONF_GENDER, default=self._gender): GENDER_SCHEMA,
vol.Optional(CONF_VOICE, default=self._voice): VOICE_SCHEMA,
vol.Optional(CONF_ENCODING, default=self._encoding): SCHEMA_ENCODING,
vol.Optional(CONF_SPEED, default=self._speed): SPEED_SCHEMA,
vol.Optional(CONF_PITCH, default=self._pitch): PITCH_SCHEMA,
vol.Optional(CONF_GAIN, default=self._gain): GAIN_SCHEMA,
vol.Optional(CONF_PROFILES, default=self._profiles): PROFILES_SCHEMA,
vol.Optional(CONF_TEXT_TYPE, default=self._text_type): TEXT_TYPE_SCHEMA,
}
)
try:
options = options_schema(options)
options = self._options_schema(options)
except vol.Invalid as err:
_LOGGER.error("Error: %s when validating options: %s", err, options)
return None, None