From 17daccd38afd162d4ac6c7005033e1239011e5c6 Mon Sep 17 00:00:00 2001 From: tronikos Date: Sat, 6 Jul 2024 02:44:46 -0700 Subject: [PATCH] Refactor the validation in Google Cloud TTS (#120853) --- .../components/google_cloud/const.py | 16 ++ .../components/google_cloud/helpers.py | 132 +++++++++++++++ homeassistant/components/google_cloud/tts.py | 154 +++--------------- 3 files changed, 170 insertions(+), 132 deletions(-) create mode 100644 homeassistant/components/google_cloud/const.py diff --git a/homeassistant/components/google_cloud/const.py b/homeassistant/components/google_cloud/const.py new file mode 100644 index 00000000000..0fbd5e78274 --- /dev/null +++ b/homeassistant/components/google_cloud/const.py @@ -0,0 +1,16 @@ +"""Constants for the Google Cloud component.""" + +from __future__ import annotations + +CONF_KEY_FILE = "key_file" + +DEFAULT_LANG = "en-US" + +CONF_GENDER = "gender" +CONF_VOICE = "voice" +CONF_ENCODING = "encoding" +CONF_SPEED = "speed" +CONF_PITCH = "pitch" +CONF_GAIN = "gain" +CONF_PROFILES = "profiles" +CONF_TEXT_TYPE = "text_type" diff --git a/homeassistant/components/google_cloud/helpers.py b/homeassistant/components/google_cloud/helpers.py index 6a890f90cc7..39e26844dd2 100644 --- a/homeassistant/components/google_cloud/helpers.py +++ b/homeassistant/components/google_cloud/helpers.py @@ -2,7 +2,36 @@ from __future__ import annotations +from types import MappingProxyType +from typing import Any + from google.cloud import texttospeech +import voluptuous as vol + +from homeassistant.components.tts import CONF_LANG +import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + SelectSelector, + SelectSelectorConfig, + SelectSelectorMode, +) + +from .const import ( + CONF_ENCODING, + CONF_GAIN, + CONF_GENDER, + CONF_KEY_FILE, + CONF_PITCH, + CONF_PROFILES, + CONF_SPEED, + CONF_TEXT_TYPE, + CONF_VOICE, + DEFAULT_LANG, +) + +DEFAULT_VOICE = "" async def async_tts_voices( @@ -17,3 +46,106 @@ async def async_tts_voices( voices[language_code] = [] voices[language_code].append(voice.name) return voices + + +def tts_options_schema( + config_options: MappingProxyType[str, Any], voices: dict[str, list[str]] +): + """Return schema for TTS options with default values from config or constants.""" + return vol.Schema( + { + vol.Optional( + CONF_GENDER, + description={"suggested_value": config_options.get(CONF_GENDER)}, + default=texttospeech.SsmlVoiceGender.NEUTRAL.name, # type: ignore[attr-defined] + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=list(texttospeech.SsmlVoiceGender.__members__), + ) + ), + vol.Optional( + CONF_VOICE, + description={"suggested_value": config_options.get(CONF_VOICE)}, + default=DEFAULT_VOICE, + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=["", *sum(voices.values(), [])], + ) + ), + vol.Optional( + CONF_ENCODING, + description={"suggested_value": config_options.get(CONF_ENCODING)}, + default=texttospeech.AudioEncoding.MP3.name, # type: ignore[attr-defined] + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=list(texttospeech.AudioEncoding.__members__), + ) + ), + vol.Optional( + CONF_SPEED, + description={"suggested_value": config_options.get(CONF_SPEED)}, + default=1.0, + ): NumberSelector(NumberSelectorConfig(min=0.25, max=4.0, step=0.01)), + vol.Optional( + CONF_PITCH, + description={"suggested_value": config_options.get(CONF_PITCH)}, + default=0, + ): NumberSelector(NumberSelectorConfig(min=-20.0, max=20.0, step=0.1)), + vol.Optional( + CONF_GAIN, + description={"suggested_value": config_options.get(CONF_GAIN)}, + default=0, + ): NumberSelector(NumberSelectorConfig(min=-96.0, max=16.0, step=0.1)), + vol.Optional( + CONF_PROFILES, + description={"suggested_value": config_options.get(CONF_PROFILES)}, + default=[], + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=[ + # https://cloud.google.com/text-to-speech/docs/audio-profiles + "wearable-class-device", + "handset-class-device", + "headphone-class-device", + "small-bluetooth-speaker-class-device", + "medium-bluetooth-speaker-class-device", + "large-home-entertainment-class-device", + "large-automotive-class-device", + "telephony-class-application", + ], + multiple=True, + sort=False, + ) + ), + vol.Optional( + CONF_TEXT_TYPE, + description={"suggested_value": config_options.get(CONF_TEXT_TYPE)}, + default="text", + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=["text", "ssml"], + ) + ), + } + ) + + +def tts_platform_schema(): + """Return schema for TTS platform.""" + return vol.Schema( + { + vol.Optional(CONF_KEY_FILE): cv.string, + vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex( + r"[a-z]{2,3}-[A-Z]{2}|" + ), + **tts_options_schema({}, {}).schema, + vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): cv.matches_regex( + r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|" + ), + } + ) diff --git a/homeassistant/components/google_cloud/tts.py b/homeassistant/components/google_cloud/tts.py index 92a8a6cdf5e..ee9999fc496 100644 --- a/homeassistant/components/google_cloud/tts.py +++ b/homeassistant/components/google_cloud/tts.py @@ -14,92 +14,24 @@ from homeassistant.components.tts import ( Voice, ) from homeassistant.core import HomeAssistant, callback -import homeassistant.helpers.config_validation as cv -from .helpers import async_tts_voices +from .const import ( + CONF_ENCODING, + CONF_GAIN, + CONF_GENDER, + CONF_KEY_FILE, + CONF_PITCH, + CONF_PROFILES, + CONF_SPEED, + CONF_TEXT_TYPE, + CONF_VOICE, + DEFAULT_LANG, +) +from .helpers import async_tts_voices, tts_options_schema, tts_platform_schema _LOGGER = logging.getLogger(__name__) -CONF_KEY_FILE = "key_file" -CONF_GENDER = "gender" -CONF_VOICE = "voice" -CONF_ENCODING = "encoding" -CONF_SPEED = "speed" -CONF_PITCH = "pitch" -CONF_GAIN = "gain" -CONF_PROFILES = "profiles" -CONF_TEXT_TYPE = "text_type" - -DEFAULT_LANG = "en-US" - -DEFAULT_GENDER = "NEUTRAL" - -LANG_REGEX = r"[a-z]{2,3}-[A-Z]{2}|" -VOICE_REGEX = r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|" -DEFAULT_VOICE = "" - -DEFAULT_ENCODING = "MP3" - -MIN_SPEED = 0.25 -MAX_SPEED = 4.0 -DEFAULT_SPEED = 1.0 - -MIN_PITCH = -20.0 -MAX_PITCH = 20.0 -DEFAULT_PITCH = 0 - -MIN_GAIN = -96.0 -MAX_GAIN = 16.0 -DEFAULT_GAIN = 0 - -SUPPORTED_TEXT_TYPES = ["text", "ssml"] -DEFAULT_TEXT_TYPE = "text" - -SUPPORTED_PROFILES = [ - "wearable-class-device", - "handset-class-device", - "headphone-class-device", - "small-bluetooth-speaker-class-device", - "medium-bluetooth-speaker-class-device", - "large-home-entertainment-class-device", - "large-automotive-class-device", - "telephony-class-application", -] - -SUPPORTED_OPTIONS = [ - CONF_VOICE, - CONF_GENDER, - CONF_ENCODING, - CONF_SPEED, - CONF_PITCH, - CONF_GAIN, - CONF_PROFILES, - CONF_TEXT_TYPE, -] - -GENDER_SCHEMA = vol.All(vol.Upper, vol.In(texttospeech.SsmlVoiceGender.__members__)) -VOICE_SCHEMA = cv.matches_regex(VOICE_REGEX) -SCHEMA_ENCODING = vol.All(vol.Upper, vol.In(texttospeech.AudioEncoding.__members__)) -SPEED_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_SPEED, max=MAX_SPEED)) -PITCH_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_PITCH, max=MAX_PITCH)) -GAIN_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_GAIN, max=MAX_GAIN)) -PROFILES_SCHEMA = vol.All(cv.ensure_list, [vol.In(SUPPORTED_PROFILES)]) -TEXT_TYPE_SCHEMA = vol.All(vol.Lower, vol.In(SUPPORTED_TEXT_TYPES)) - -PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend( - { - vol.Optional(CONF_KEY_FILE): cv.string, - vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex(LANG_REGEX), - vol.Optional(CONF_GENDER, default=DEFAULT_GENDER): GENDER_SCHEMA, - vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): VOICE_SCHEMA, - vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): SCHEMA_ENCODING, - vol.Optional(CONF_SPEED, default=DEFAULT_SPEED): SPEED_SCHEMA, - vol.Optional(CONF_PITCH, default=DEFAULT_PITCH): PITCH_SCHEMA, - vol.Optional(CONF_GAIN, default=DEFAULT_GAIN): GAIN_SCHEMA, - vol.Optional(CONF_PROFILES, default=[]): PROFILES_SCHEMA, - vol.Optional(CONF_TEXT_TYPE, default=DEFAULT_TEXT_TYPE): TEXT_TYPE_SCHEMA, - } -) +PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(tts_platform_schema().schema) async def async_get_engine(hass, config, discovery_info=None): @@ -124,15 +56,8 @@ async def async_get_engine(hass, config, discovery_info=None): hass, client, voices, - config[CONF_LANG], - config[CONF_GENDER], - config[CONF_VOICE], - config[CONF_ENCODING], - config[CONF_SPEED], - config[CONF_PITCH], - config[CONF_GAIN], - config[CONF_PROFILES], - config[CONF_TEXT_TYPE], + config.get(CONF_LANG, DEFAULT_LANG), + tts_options_schema(config, voices), ) @@ -144,15 +69,8 @@ class GoogleCloudTTSProvider(Provider): hass: HomeAssistant, client: texttospeech.TextToSpeechAsyncClient, voices: dict[str, list[str]], - language=DEFAULT_LANG, - gender=DEFAULT_GENDER, - voice=DEFAULT_VOICE, - encoding=DEFAULT_ENCODING, - speed=1.0, - pitch=0, - gain=0, - profiles=None, - text_type=DEFAULT_TEXT_TYPE, + language, + options_schema, ) -> None: """Init Google Cloud TTS service.""" self.hass = hass @@ -160,14 +78,7 @@ class GoogleCloudTTSProvider(Provider): self._client = client self._voices = voices self._language = language - self._gender = gender - self._voice = voice - self._encoding = encoding - self._speed = speed - self._pitch = pitch - self._gain = gain - self._profiles = profiles - self._text_type = text_type + self._options_schema = options_schema @property def supported_languages(self): @@ -182,21 +93,12 @@ class GoogleCloudTTSProvider(Provider): @property def supported_options(self): """Return a list of supported options.""" - return SUPPORTED_OPTIONS + return [option.schema for option in self._options_schema.schema] @property def default_options(self): """Return a dict including default options.""" - return { - CONF_GENDER: self._gender, - CONF_VOICE: self._voice, - CONF_ENCODING: self._encoding, - CONF_SPEED: self._speed, - CONF_PITCH: self._pitch, - CONF_GAIN: self._gain, - CONF_PROFILES: self._profiles, - CONF_TEXT_TYPE: self._text_type, - } + return self._options_schema({}) @callback def async_get_supported_voices(self, language: str) -> list[Voice] | None: @@ -207,20 +109,8 @@ class GoogleCloudTTSProvider(Provider): async def async_get_tts_audio(self, message, language, options): """Load TTS from google.""" - options_schema = vol.Schema( - { - vol.Optional(CONF_GENDER, default=self._gender): GENDER_SCHEMA, - vol.Optional(CONF_VOICE, default=self._voice): VOICE_SCHEMA, - vol.Optional(CONF_ENCODING, default=self._encoding): SCHEMA_ENCODING, - vol.Optional(CONF_SPEED, default=self._speed): SPEED_SCHEMA, - vol.Optional(CONF_PITCH, default=self._pitch): PITCH_SCHEMA, - vol.Optional(CONF_GAIN, default=self._gain): GAIN_SCHEMA, - vol.Optional(CONF_PROFILES, default=self._profiles): PROFILES_SCHEMA, - vol.Optional(CONF_TEXT_TYPE, default=self._text_type): TEXT_TYPE_SCHEMA, - } - ) try: - options = options_schema(options) + options = self._options_schema(options) except vol.Invalid as err: _LOGGER.error("Error: %s when validating options: %s", err, options) return None, None