Add strict type annotations to amazon_polly (#50697)

* add strict type annotations

* apply suggestions

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Michael 2021-05-17 14:09:52 +02:00 committed by GitHub
parent 9e86602950
commit df6862a519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 226 additions and 141 deletions

View File

@ -44,7 +44,7 @@ omit =
homeassistant/components/alarmdecoder/const.py
homeassistant/components/alarmdecoder/sensor.py
homeassistant/components/alpha_vantage/sensor.py
homeassistant/components/amazon_polly/tts.py
homeassistant/components/amazon_polly/*
homeassistant/components/ambiclimate/climate.py
homeassistant/components/ambient_station/*
homeassistant/components/amcrest/*

View File

@ -8,6 +8,7 @@ homeassistant.components.actiontec.*
homeassistant.components.aftership.*
homeassistant.components.airly.*
homeassistant.components.aladdin_connect.*
homeassistant.components.amazon_polly.*
homeassistant.components.ampio.*
homeassistant.components.automation.*
homeassistant.components.binary_sensor.*

View File

@ -0,0 +1,131 @@
"""Constants for the Amazon Polly text to speech service."""
from __future__ import annotations
from typing import Final
CONF_REGION: Final = "region_name"
CONF_ACCESS_KEY_ID: Final = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY: Final = "aws_secret_access_key"
DEFAULT_REGION: Final = "us-east-1"
SUPPORTED_REGIONS: Final[list[str]] = [
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"ca-central-1",
"eu-west-1",
"eu-central-1",
"eu-west-2",
"eu-west-3",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-2",
"ap-northeast-1",
"ap-south-1",
"sa-east-1",
]
CONF_ENGINE: Final = "engine"
CONF_VOICE: Final = "voice"
CONF_OUTPUT_FORMAT: Final = "output_format"
CONF_SAMPLE_RATE: Final = "sample_rate"
CONF_TEXT_TYPE: Final = "text_type"
SUPPORTED_VOICES: Final[list[str]] = [
"Olivia", # Female, Australian, Neural
"Zhiyu", # Chinese
"Mads",
"Naja", # Danish
"Ruben",
"Lotte", # Dutch
"Russell",
"Nicole", # English Australian
"Brian",
"Amy",
"Emma", # English
"Aditi",
"Raveena", # English, Indian
"Joey",
"Justin",
"Matthew",
"Ivy",
"Joanna",
"Kendra",
"Kimberly",
"Salli", # English
"Geraint", # English Welsh
"Mathieu",
"Celine",
"Lea", # French
"Chantal", # French Canadian
"Hans",
"Marlene",
"Vicki", # German
"Aditi", # Hindi
"Karl",
"Dora", # Icelandic
"Giorgio",
"Carla",
"Bianca", # Italian
"Takumi",
"Mizuki", # Japanese
"Seoyeon", # Korean
"Liv", # Norwegian
"Jacek",
"Jan",
"Ewa",
"Maja", # Polish
"Ricardo",
"Vitoria", # Portuguese, Brazilian
"Cristiano",
"Ines", # Portuguese, European
"Carmen", # Romanian
"Maxim",
"Tatyana", # Russian
"Enrique",
"Conchita",
"Lucia", # Spanish European
"Mia", # Spanish Mexican
"Miguel", # Spanish US
"Penelope", # Spanish US
"Lupe", # Spanish US
"Astrid", # Swedish
"Filiz", # Turkish
"Gwyneth", # Welsh
]
SUPPORTED_OUTPUT_FORMATS: Final[list[str]] = ["mp3", "ogg_vorbis", "pcm"]
SUPPORTED_ENGINES: Final[list[str]] = ["neural", "standard"]
SUPPORTED_SAMPLE_RATES: Final[list[str]] = ["8000", "16000", "22050", "24000"]
SUPPORTED_SAMPLE_RATES_MAP: Final[dict[str, list[str]]] = {
"mp3": ["8000", "16000", "22050", "24000"],
"ogg_vorbis": ["8000", "16000", "22050"],
"pcm": ["8000", "16000"],
}
SUPPORTED_TEXT_TYPES: Final[list[str]] = ["text", "ssml"]
CONTENT_TYPE_EXTENSIONS: Final[dict[str, str]] = {
"audio/mpeg": "mp3",
"audio/ogg": "ogg",
"audio/pcm": "pcm",
}
DEFAULT_ENGINE: Final = "standard"
DEFAULT_VOICE: Final = "Joanna"
DEFAULT_OUTPUT_FORMAT: Final = "mp3"
DEFAULT_TEXT_TYPE: Final = "text"
DEFAULT_SAMPLE_RATES: Final[dict[str, str]] = {
"mp3": "22050",
"ogg_vorbis": "22050",
"pcm": "16000",
}
AWS_CONF_CONNECT_TIMEOUT: Final = 10
AWS_CONF_READ_TIMEOUT: Final = 5
AWS_CONF_MAX_POOL_CONNECTIONS: Final = 1

View File

@ -1,136 +1,54 @@
"""Support for the Amazon Polly text to speech service."""
from __future__ import annotations
import logging
from typing import Final
import boto3
import botocore
import voluptuous as vol
from homeassistant.components.tts import PLATFORM_SCHEMA, Provider
from homeassistant.components.tts import (
PLATFORM_SCHEMA as BASE_PLATFORM_SCHEMA,
Provider,
TtsAudioType,
)
from homeassistant.const import ATTR_CREDENTIALS, CONF_PROFILE_NAME
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
_LOGGER = logging.getLogger(__name__)
from .const import (
AWS_CONF_CONNECT_TIMEOUT,
AWS_CONF_MAX_POOL_CONNECTIONS,
AWS_CONF_READ_TIMEOUT,
CONF_ACCESS_KEY_ID,
CONF_ENGINE,
CONF_OUTPUT_FORMAT,
CONF_REGION,
CONF_SAMPLE_RATE,
CONF_SECRET_ACCESS_KEY,
CONF_TEXT_TYPE,
CONF_VOICE,
CONTENT_TYPE_EXTENSIONS,
DEFAULT_ENGINE,
DEFAULT_OUTPUT_FORMAT,
DEFAULT_REGION,
DEFAULT_SAMPLE_RATES,
DEFAULT_TEXT_TYPE,
DEFAULT_VOICE,
SUPPORTED_ENGINES,
SUPPORTED_OUTPUT_FORMATS,
SUPPORTED_REGIONS,
SUPPORTED_SAMPLE_RATES,
SUPPORTED_SAMPLE_RATES_MAP,
SUPPORTED_TEXT_TYPES,
SUPPORTED_VOICES,
)
CONF_REGION = "region_name"
CONF_ACCESS_KEY_ID = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY = "aws_secret_access_key"
_LOGGER: Final = logging.getLogger(__name__)
DEFAULT_REGION = "us-east-1"
SUPPORTED_REGIONS = [
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"ca-central-1",
"eu-west-1",
"eu-central-1",
"eu-west-2",
"eu-west-3",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-2",
"ap-northeast-1",
"ap-south-1",
"sa-east-1",
]
CONF_ENGINE = "engine"
CONF_VOICE = "voice"
CONF_OUTPUT_FORMAT = "output_format"
CONF_SAMPLE_RATE = "sample_rate"
CONF_TEXT_TYPE = "text_type"
SUPPORTED_VOICES = [
"Olivia", # Female, Australian, Neural
"Zhiyu", # Chinese
"Mads",
"Naja", # Danish
"Ruben",
"Lotte", # Dutch
"Russell",
"Nicole", # English Australian
"Brian",
"Amy",
"Emma", # English
"Aditi",
"Raveena", # English, Indian
"Joey",
"Justin",
"Matthew",
"Ivy",
"Joanna",
"Kendra",
"Kimberly",
"Salli", # English
"Geraint", # English Welsh
"Mathieu",
"Celine",
"Lea", # French
"Chantal", # French Canadian
"Hans",
"Marlene",
"Vicki", # German
"Aditi", # Hindi
"Karl",
"Dora", # Icelandic
"Giorgio",
"Carla",
"Bianca", # Italian
"Takumi",
"Mizuki", # Japanese
"Seoyeon", # Korean
"Liv", # Norwegian
"Jacek",
"Jan",
"Ewa",
"Maja", # Polish
"Ricardo",
"Vitoria", # Portuguese, Brazilian
"Cristiano",
"Ines", # Portuguese, European
"Carmen", # Romanian
"Maxim",
"Tatyana", # Russian
"Enrique",
"Conchita",
"Lucia", # Spanish European
"Mia", # Spanish Mexican
"Miguel", # Spanish US
"Penelope", # Spanish US
"Lupe", # Spanish US
"Astrid", # Swedish
"Filiz", # Turkish
"Gwyneth", # Welsh
]
SUPPORTED_OUTPUT_FORMATS = ["mp3", "ogg_vorbis", "pcm"]
SUPPORTED_ENGINES = ["neural", "standard"]
SUPPORTED_SAMPLE_RATES = ["8000", "16000", "22050", "24000"]
SUPPORTED_SAMPLE_RATES_MAP = {
"mp3": ["8000", "16000", "22050", "24000"],
"ogg_vorbis": ["8000", "16000", "22050"],
"pcm": ["8000", "16000"],
}
SUPPORTED_TEXT_TYPES = ["text", "ssml"]
CONTENT_TYPE_EXTENSIONS = {"audio/mpeg": "mp3", "audio/ogg": "ogg", "audio/pcm": "pcm"}
DEFAULT_ENGINE = "standard"
DEFAULT_VOICE = "Joanna"
DEFAULT_OUTPUT_FORMAT = "mp3"
DEFAULT_TEXT_TYPE = "text"
DEFAULT_SAMPLE_RATES = {"mp3": "22050", "ogg_vorbis": "22050", "pcm": "16000"}
AWS_CONF_CONNECT_TIMEOUT = 10
AWS_CONF_READ_TIMEOUT = 5
AWS_CONF_MAX_POOL_CONNECTIONS = 1
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
PLATFORM_SCHEMA: Final = BASE_PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_REGION, default=DEFAULT_REGION): vol.In(SUPPORTED_REGIONS),
vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string,
@ -151,11 +69,15 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
def get_engine(hass, config, discovery_info=None):
def get_engine(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Set up Amazon Polly speech component."""
output_format = config[CONF_OUTPUT_FORMAT]
sample_rate = config.get(CONF_SAMPLE_RATE, DEFAULT_SAMPLE_RATES[output_format])
if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP.get(output_format):
if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP[output_format]:
_LOGGER.error(
"%s is not a valid sample rate for %s", sample_rate, output_format
)
@ -163,7 +85,7 @@ def get_engine(hass, config, discovery_info=None):
config[CONF_SAMPLE_RATE] = sample_rate
profile = config.get(CONF_PROFILE_NAME)
profile: str | None = config.get(CONF_PROFILE_NAME)
if profile is not None:
boto3.setup_default_session(profile_name=profile)
@ -185,16 +107,20 @@ def get_engine(hass, config, discovery_info=None):
polly_client = boto3.client("polly", **aws_config)
supported_languages = []
supported_languages: list[str] = []
all_voices = {}
all_voices: dict[str, dict[str, str]] = {}
all_voices_req = polly_client.describe_voices()
for voice in all_voices_req.get("Voices"):
all_voices[voice.get("Id")] = voice
if voice.get("LanguageCode") not in supported_languages:
supported_languages.append(voice.get("LanguageCode"))
for voice in all_voices_req.get("Voices", []):
voice_id: str | None = voice.get("Id")
if voice_id is None:
continue
all_voices[voice_id] = voice
language_code: str | None = voice.get("LanguageCode")
if language_code is not None and language_code not in supported_languages:
supported_languages.append(language_code)
return AmazonPollyProvider(polly_client, config, supported_languages, all_voices)
@ -202,39 +128,53 @@ def get_engine(hass, config, discovery_info=None):
class AmazonPollyProvider(Provider):
"""Amazon Polly speech api provider."""
def __init__(self, polly_client, config, supported_languages, all_voices):
def __init__(
self,
polly_client: boto3.client,
config: ConfigType,
supported_languages: list[str],
all_voices: dict[str, dict[str, str]],
) -> None:
"""Initialize Amazon Polly provider for TTS."""
self.client = polly_client
self.config = config
self.supported_langs = supported_languages
self.all_voices = all_voices
self.default_voice = self.config[CONF_VOICE]
self.default_voice: str = self.config[CONF_VOICE]
self.name = "Amazon Polly"
@property
def supported_languages(self):
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self.supported_langs
@property
def default_language(self):
def default_language(self) -> str | None:
"""Return the default language."""
return self.all_voices.get(self.default_voice).get("LanguageCode")
return self.all_voices.get(self.default_voice, {}).get("LanguageCode")
@property
def default_options(self):
def default_options(self) -> dict[str, str]:
"""Return dict include default options."""
return {CONF_VOICE: self.default_voice}
@property
def supported_options(self):
def supported_options(self) -> list[str]:
"""Return a list of supported options."""
return [CONF_VOICE]
def get_tts_audio(self, message, language=None, options=None):
def get_tts_audio(
self,
message: str,
language: str | None = None,
options: dict[str, str] | None = None,
) -> TtsAudioType:
"""Request TTS file from Polly."""
if options is None or language is None:
_LOGGER.debug("language and/or options were missing")
return None, None
voice_id = options.get(CONF_VOICE, self.default_voice)
voice_in_dict = self.all_voices.get(voice_id)
voice_in_dict = self.all_voices[voice_id]
if language != voice_in_dict.get("LanguageCode"):
_LOGGER.error("%s does not support the %s language", voice_id, language)
return None, None

View File

@ -9,7 +9,7 @@ import logging
import mimetypes
import os
import re
from typing import cast
from typing import Optional, Tuple, cast
from aiohttp import web
import mutagen
@ -47,6 +47,8 @@ from homeassistant.util.yaml import load_yaml
_LOGGER = logging.getLogger(__name__)
TtsAudioType = Tuple[Optional[str], Optional[bytes]]
ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"

View File

@ -99,6 +99,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.amazon_polly.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.ampio.*]
check_untyped_defs = true
disallow_incomplete_defs = true