mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Refactor the validation in Google Cloud TTS (#120853)
This commit is contained in:
parent
8ff4991f07
commit
17daccd38a
16
homeassistant/components/google_cloud/const.py
Normal file
16
homeassistant/components/google_cloud/const.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""Constants for the Google Cloud component."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
CONF_KEY_FILE = "key_file"
|
||||
|
||||
DEFAULT_LANG = "en-US"
|
||||
|
||||
CONF_GENDER = "gender"
|
||||
CONF_VOICE = "voice"
|
||||
CONF_ENCODING = "encoding"
|
||||
CONF_SPEED = "speed"
|
||||
CONF_PITCH = "pitch"
|
||||
CONF_GAIN = "gain"
|
||||
CONF_PROFILES = "profiles"
|
||||
CONF_TEXT_TYPE = "text_type"
|
@ -2,7 +2,36 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
from google.cloud import texttospeech
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.tts import CONF_LANG
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.selector import (
|
||||
NumberSelector,
|
||||
NumberSelectorConfig,
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
SelectSelectorMode,
|
||||
)
|
||||
|
||||
from .const import (
|
||||
CONF_ENCODING,
|
||||
CONF_GAIN,
|
||||
CONF_GENDER,
|
||||
CONF_KEY_FILE,
|
||||
CONF_PITCH,
|
||||
CONF_PROFILES,
|
||||
CONF_SPEED,
|
||||
CONF_TEXT_TYPE,
|
||||
CONF_VOICE,
|
||||
DEFAULT_LANG,
|
||||
)
|
||||
|
||||
DEFAULT_VOICE = ""
|
||||
|
||||
|
||||
async def async_tts_voices(
|
||||
@ -17,3 +46,106 @@ async def async_tts_voices(
|
||||
voices[language_code] = []
|
||||
voices[language_code].append(voice.name)
|
||||
return voices
|
||||
|
||||
|
||||
def tts_options_schema(
|
||||
config_options: MappingProxyType[str, Any], voices: dict[str, list[str]]
|
||||
):
|
||||
"""Return schema for TTS options with default values from config or constants."""
|
||||
return vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_GENDER,
|
||||
description={"suggested_value": config_options.get(CONF_GENDER)},
|
||||
default=texttospeech.SsmlVoiceGender.NEUTRAL.name, # type: ignore[attr-defined]
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
options=list(texttospeech.SsmlVoiceGender.__members__),
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_VOICE,
|
||||
description={"suggested_value": config_options.get(CONF_VOICE)},
|
||||
default=DEFAULT_VOICE,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
options=["", *sum(voices.values(), [])],
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_ENCODING,
|
||||
description={"suggested_value": config_options.get(CONF_ENCODING)},
|
||||
default=texttospeech.AudioEncoding.MP3.name, # type: ignore[attr-defined]
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
options=list(texttospeech.AudioEncoding.__members__),
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_SPEED,
|
||||
description={"suggested_value": config_options.get(CONF_SPEED)},
|
||||
default=1.0,
|
||||
): NumberSelector(NumberSelectorConfig(min=0.25, max=4.0, step=0.01)),
|
||||
vol.Optional(
|
||||
CONF_PITCH,
|
||||
description={"suggested_value": config_options.get(CONF_PITCH)},
|
||||
default=0,
|
||||
): NumberSelector(NumberSelectorConfig(min=-20.0, max=20.0, step=0.1)),
|
||||
vol.Optional(
|
||||
CONF_GAIN,
|
||||
description={"suggested_value": config_options.get(CONF_GAIN)},
|
||||
default=0,
|
||||
): NumberSelector(NumberSelectorConfig(min=-96.0, max=16.0, step=0.1)),
|
||||
vol.Optional(
|
||||
CONF_PROFILES,
|
||||
description={"suggested_value": config_options.get(CONF_PROFILES)},
|
||||
default=[],
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
options=[
|
||||
# https://cloud.google.com/text-to-speech/docs/audio-profiles
|
||||
"wearable-class-device",
|
||||
"handset-class-device",
|
||||
"headphone-class-device",
|
||||
"small-bluetooth-speaker-class-device",
|
||||
"medium-bluetooth-speaker-class-device",
|
||||
"large-home-entertainment-class-device",
|
||||
"large-automotive-class-device",
|
||||
"telephony-class-application",
|
||||
],
|
||||
multiple=True,
|
||||
sort=False,
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_TEXT_TYPE,
|
||||
description={"suggested_value": config_options.get(CONF_TEXT_TYPE)},
|
||||
default="text",
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
mode=SelectSelectorMode.DROPDOWN,
|
||||
options=["text", "ssml"],
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def tts_platform_schema():
|
||||
"""Return schema for TTS platform."""
|
||||
return vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_KEY_FILE): cv.string,
|
||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex(
|
||||
r"[a-z]{2,3}-[A-Z]{2}|"
|
||||
),
|
||||
**tts_options_schema({}, {}).schema,
|
||||
vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): cv.matches_regex(
|
||||
r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
@ -14,92 +14,24 @@ from homeassistant.components.tts import (
|
||||
Voice,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from .helpers import async_tts_voices
|
||||
from .const import (
|
||||
CONF_ENCODING,
|
||||
CONF_GAIN,
|
||||
CONF_GENDER,
|
||||
CONF_KEY_FILE,
|
||||
CONF_PITCH,
|
||||
CONF_PROFILES,
|
||||
CONF_SPEED,
|
||||
CONF_TEXT_TYPE,
|
||||
CONF_VOICE,
|
||||
DEFAULT_LANG,
|
||||
)
|
||||
from .helpers import async_tts_voices, tts_options_schema, tts_platform_schema
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_KEY_FILE = "key_file"
|
||||
CONF_GENDER = "gender"
|
||||
CONF_VOICE = "voice"
|
||||
CONF_ENCODING = "encoding"
|
||||
CONF_SPEED = "speed"
|
||||
CONF_PITCH = "pitch"
|
||||
CONF_GAIN = "gain"
|
||||
CONF_PROFILES = "profiles"
|
||||
CONF_TEXT_TYPE = "text_type"
|
||||
|
||||
DEFAULT_LANG = "en-US"
|
||||
|
||||
DEFAULT_GENDER = "NEUTRAL"
|
||||
|
||||
LANG_REGEX = r"[a-z]{2,3}-[A-Z]{2}|"
|
||||
VOICE_REGEX = r"[a-z]{2,3}-[A-Z]{2}-.*-[A-Z]|"
|
||||
DEFAULT_VOICE = ""
|
||||
|
||||
DEFAULT_ENCODING = "MP3"
|
||||
|
||||
MIN_SPEED = 0.25
|
||||
MAX_SPEED = 4.0
|
||||
DEFAULT_SPEED = 1.0
|
||||
|
||||
MIN_PITCH = -20.0
|
||||
MAX_PITCH = 20.0
|
||||
DEFAULT_PITCH = 0
|
||||
|
||||
MIN_GAIN = -96.0
|
||||
MAX_GAIN = 16.0
|
||||
DEFAULT_GAIN = 0
|
||||
|
||||
SUPPORTED_TEXT_TYPES = ["text", "ssml"]
|
||||
DEFAULT_TEXT_TYPE = "text"
|
||||
|
||||
SUPPORTED_PROFILES = [
|
||||
"wearable-class-device",
|
||||
"handset-class-device",
|
||||
"headphone-class-device",
|
||||
"small-bluetooth-speaker-class-device",
|
||||
"medium-bluetooth-speaker-class-device",
|
||||
"large-home-entertainment-class-device",
|
||||
"large-automotive-class-device",
|
||||
"telephony-class-application",
|
||||
]
|
||||
|
||||
SUPPORTED_OPTIONS = [
|
||||
CONF_VOICE,
|
||||
CONF_GENDER,
|
||||
CONF_ENCODING,
|
||||
CONF_SPEED,
|
||||
CONF_PITCH,
|
||||
CONF_GAIN,
|
||||
CONF_PROFILES,
|
||||
CONF_TEXT_TYPE,
|
||||
]
|
||||
|
||||
GENDER_SCHEMA = vol.All(vol.Upper, vol.In(texttospeech.SsmlVoiceGender.__members__))
|
||||
VOICE_SCHEMA = cv.matches_regex(VOICE_REGEX)
|
||||
SCHEMA_ENCODING = vol.All(vol.Upper, vol.In(texttospeech.AudioEncoding.__members__))
|
||||
SPEED_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_SPEED, max=MAX_SPEED))
|
||||
PITCH_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_PITCH, max=MAX_PITCH))
|
||||
GAIN_SCHEMA = vol.All(vol.Coerce(float), vol.Clamp(min=MIN_GAIN, max=MAX_GAIN))
|
||||
PROFILES_SCHEMA = vol.All(cv.ensure_list, [vol.In(SUPPORTED_PROFILES)])
|
||||
TEXT_TYPE_SCHEMA = vol.All(vol.Lower, vol.In(SUPPORTED_TEXT_TYPES))
|
||||
|
||||
PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_KEY_FILE): cv.string,
|
||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.matches_regex(LANG_REGEX),
|
||||
vol.Optional(CONF_GENDER, default=DEFAULT_GENDER): GENDER_SCHEMA,
|
||||
vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): VOICE_SCHEMA,
|
||||
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): SCHEMA_ENCODING,
|
||||
vol.Optional(CONF_SPEED, default=DEFAULT_SPEED): SPEED_SCHEMA,
|
||||
vol.Optional(CONF_PITCH, default=DEFAULT_PITCH): PITCH_SCHEMA,
|
||||
vol.Optional(CONF_GAIN, default=DEFAULT_GAIN): GAIN_SCHEMA,
|
||||
vol.Optional(CONF_PROFILES, default=[]): PROFILES_SCHEMA,
|
||||
vol.Optional(CONF_TEXT_TYPE, default=DEFAULT_TEXT_TYPE): TEXT_TYPE_SCHEMA,
|
||||
}
|
||||
)
|
||||
PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(tts_platform_schema().schema)
|
||||
|
||||
|
||||
async def async_get_engine(hass, config, discovery_info=None):
|
||||
@ -124,15 +56,8 @@ async def async_get_engine(hass, config, discovery_info=None):
|
||||
hass,
|
||||
client,
|
||||
voices,
|
||||
config[CONF_LANG],
|
||||
config[CONF_GENDER],
|
||||
config[CONF_VOICE],
|
||||
config[CONF_ENCODING],
|
||||
config[CONF_SPEED],
|
||||
config[CONF_PITCH],
|
||||
config[CONF_GAIN],
|
||||
config[CONF_PROFILES],
|
||||
config[CONF_TEXT_TYPE],
|
||||
config.get(CONF_LANG, DEFAULT_LANG),
|
||||
tts_options_schema(config, voices),
|
||||
)
|
||||
|
||||
|
||||
@ -144,15 +69,8 @@ class GoogleCloudTTSProvider(Provider):
|
||||
hass: HomeAssistant,
|
||||
client: texttospeech.TextToSpeechAsyncClient,
|
||||
voices: dict[str, list[str]],
|
||||
language=DEFAULT_LANG,
|
||||
gender=DEFAULT_GENDER,
|
||||
voice=DEFAULT_VOICE,
|
||||
encoding=DEFAULT_ENCODING,
|
||||
speed=1.0,
|
||||
pitch=0,
|
||||
gain=0,
|
||||
profiles=None,
|
||||
text_type=DEFAULT_TEXT_TYPE,
|
||||
language,
|
||||
options_schema,
|
||||
) -> None:
|
||||
"""Init Google Cloud TTS service."""
|
||||
self.hass = hass
|
||||
@ -160,14 +78,7 @@ class GoogleCloudTTSProvider(Provider):
|
||||
self._client = client
|
||||
self._voices = voices
|
||||
self._language = language
|
||||
self._gender = gender
|
||||
self._voice = voice
|
||||
self._encoding = encoding
|
||||
self._speed = speed
|
||||
self._pitch = pitch
|
||||
self._gain = gain
|
||||
self._profiles = profiles
|
||||
self._text_type = text_type
|
||||
self._options_schema = options_schema
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
@ -182,21 +93,12 @@ class GoogleCloudTTSProvider(Provider):
|
||||
@property
|
||||
def supported_options(self):
|
||||
"""Return a list of supported options."""
|
||||
return SUPPORTED_OPTIONS
|
||||
return [option.schema for option in self._options_schema.schema]
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
"""Return a dict including default options."""
|
||||
return {
|
||||
CONF_GENDER: self._gender,
|
||||
CONF_VOICE: self._voice,
|
||||
CONF_ENCODING: self._encoding,
|
||||
CONF_SPEED: self._speed,
|
||||
CONF_PITCH: self._pitch,
|
||||
CONF_GAIN: self._gain,
|
||||
CONF_PROFILES: self._profiles,
|
||||
CONF_TEXT_TYPE: self._text_type,
|
||||
}
|
||||
return self._options_schema({})
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||
@ -207,20 +109,8 @@ class GoogleCloudTTSProvider(Provider):
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from google."""
|
||||
options_schema = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_GENDER, default=self._gender): GENDER_SCHEMA,
|
||||
vol.Optional(CONF_VOICE, default=self._voice): VOICE_SCHEMA,
|
||||
vol.Optional(CONF_ENCODING, default=self._encoding): SCHEMA_ENCODING,
|
||||
vol.Optional(CONF_SPEED, default=self._speed): SPEED_SCHEMA,
|
||||
vol.Optional(CONF_PITCH, default=self._pitch): PITCH_SCHEMA,
|
||||
vol.Optional(CONF_GAIN, default=self._gain): GAIN_SCHEMA,
|
||||
vol.Optional(CONF_PROFILES, default=self._profiles): PROFILES_SCHEMA,
|
||||
vol.Optional(CONF_TEXT_TYPE, default=self._text_type): TEXT_TYPE_SCHEMA,
|
||||
}
|
||||
)
|
||||
try:
|
||||
options = options_schema(options)
|
||||
options = self._options_schema(options)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error("Error: %s when validating options: %s", err, options)
|
||||
return None, None
|
||||
|
Loading…
x
Reference in New Issue
Block a user