Add strict type annotations to amazon_polly (#50697)

* add strict type annotations

* apply suggestions

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Michael 2021-05-17 14:09:52 +02:00 committed by GitHub
parent 9e86602950
commit df6862a519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 226 additions and 141 deletions

View File

@ -44,7 +44,7 @@ omit =
homeassistant/components/alarmdecoder/const.py homeassistant/components/alarmdecoder/const.py
homeassistant/components/alarmdecoder/sensor.py homeassistant/components/alarmdecoder/sensor.py
homeassistant/components/alpha_vantage/sensor.py homeassistant/components/alpha_vantage/sensor.py
homeassistant/components/amazon_polly/tts.py homeassistant/components/amazon_polly/*
homeassistant/components/ambiclimate/climate.py homeassistant/components/ambiclimate/climate.py
homeassistant/components/ambient_station/* homeassistant/components/ambient_station/*
homeassistant/components/amcrest/* homeassistant/components/amcrest/*

View File

@ -8,6 +8,7 @@ homeassistant.components.actiontec.*
homeassistant.components.aftership.* homeassistant.components.aftership.*
homeassistant.components.airly.* homeassistant.components.airly.*
homeassistant.components.aladdin_connect.* homeassistant.components.aladdin_connect.*
homeassistant.components.amazon_polly.*
homeassistant.components.ampio.* homeassistant.components.ampio.*
homeassistant.components.automation.* homeassistant.components.automation.*
homeassistant.components.binary_sensor.* homeassistant.components.binary_sensor.*

View File

@ -0,0 +1,131 @@
"""Constants for the Amazon Polly text to speech service."""
from __future__ import annotations
from typing import Final
CONF_REGION: Final = "region_name"
CONF_ACCESS_KEY_ID: Final = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY: Final = "aws_secret_access_key"
DEFAULT_REGION: Final = "us-east-1"
SUPPORTED_REGIONS: Final[list[str]] = [
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"ca-central-1",
"eu-west-1",
"eu-central-1",
"eu-west-2",
"eu-west-3",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-2",
"ap-northeast-1",
"ap-south-1",
"sa-east-1",
]
CONF_ENGINE: Final = "engine"
CONF_VOICE: Final = "voice"
CONF_OUTPUT_FORMAT: Final = "output_format"
CONF_SAMPLE_RATE: Final = "sample_rate"
CONF_TEXT_TYPE: Final = "text_type"
SUPPORTED_VOICES: Final[list[str]] = [
"Olivia", # Female, Australian, Neural
"Zhiyu", # Chinese
"Mads",
"Naja", # Danish
"Ruben",
"Lotte", # Dutch
"Russell",
"Nicole", # English Australian
"Brian",
"Amy",
"Emma", # English
"Aditi",
"Raveena", # English, Indian
"Joey",
"Justin",
"Matthew",
"Ivy",
"Joanna",
"Kendra",
"Kimberly",
"Salli", # English
"Geraint", # English Welsh
"Mathieu",
"Celine",
"Lea", # French
"Chantal", # French Canadian
"Hans",
"Marlene",
"Vicki", # German
"Aditi", # Hindi
"Karl",
"Dora", # Icelandic
"Giorgio",
"Carla",
"Bianca", # Italian
"Takumi",
"Mizuki", # Japanese
"Seoyeon", # Korean
"Liv", # Norwegian
"Jacek",
"Jan",
"Ewa",
"Maja", # Polish
"Ricardo",
"Vitoria", # Portuguese, Brazilian
"Cristiano",
"Ines", # Portuguese, European
"Carmen", # Romanian
"Maxim",
"Tatyana", # Russian
"Enrique",
"Conchita",
"Lucia", # Spanish European
"Mia", # Spanish Mexican
"Miguel", # Spanish US
"Penelope", # Spanish US
"Lupe", # Spanish US
"Astrid", # Swedish
"Filiz", # Turkish
"Gwyneth", # Welsh
]
SUPPORTED_OUTPUT_FORMATS: Final[list[str]] = ["mp3", "ogg_vorbis", "pcm"]
SUPPORTED_ENGINES: Final[list[str]] = ["neural", "standard"]
SUPPORTED_SAMPLE_RATES: Final[list[str]] = ["8000", "16000", "22050", "24000"]
SUPPORTED_SAMPLE_RATES_MAP: Final[dict[str, list[str]]] = {
"mp3": ["8000", "16000", "22050", "24000"],
"ogg_vorbis": ["8000", "16000", "22050"],
"pcm": ["8000", "16000"],
}
SUPPORTED_TEXT_TYPES: Final[list[str]] = ["text", "ssml"]
CONTENT_TYPE_EXTENSIONS: Final[dict[str, str]] = {
"audio/mpeg": "mp3",
"audio/ogg": "ogg",
"audio/pcm": "pcm",
}
DEFAULT_ENGINE: Final = "standard"
DEFAULT_VOICE: Final = "Joanna"
DEFAULT_OUTPUT_FORMAT: Final = "mp3"
DEFAULT_TEXT_TYPE: Final = "text"
DEFAULT_SAMPLE_RATES: Final[dict[str, str]] = {
"mp3": "22050",
"ogg_vorbis": "22050",
"pcm": "16000",
}
AWS_CONF_CONNECT_TIMEOUT: Final = 10
AWS_CONF_READ_TIMEOUT: Final = 5
AWS_CONF_MAX_POOL_CONNECTIONS: Final = 1

View File

@ -1,136 +1,54 @@
"""Support for the Amazon Polly text to speech service.""" """Support for the Amazon Polly text to speech service."""
from __future__ import annotations
import logging import logging
from typing import Final
import boto3 import boto3
import botocore import botocore
import voluptuous as vol import voluptuous as vol
from homeassistant.components.tts import PLATFORM_SCHEMA, Provider from homeassistant.components.tts import (
PLATFORM_SCHEMA as BASE_PLATFORM_SCHEMA,
Provider,
TtsAudioType,
)
from homeassistant.const import ATTR_CREDENTIALS, CONF_PROFILE_NAME from homeassistant.const import ATTR_CREDENTIALS, CONF_PROFILE_NAME
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
_LOGGER = logging.getLogger(__name__) from .const import (
AWS_CONF_CONNECT_TIMEOUT,
AWS_CONF_MAX_POOL_CONNECTIONS,
AWS_CONF_READ_TIMEOUT,
CONF_ACCESS_KEY_ID,
CONF_ENGINE,
CONF_OUTPUT_FORMAT,
CONF_REGION,
CONF_SAMPLE_RATE,
CONF_SECRET_ACCESS_KEY,
CONF_TEXT_TYPE,
CONF_VOICE,
CONTENT_TYPE_EXTENSIONS,
DEFAULT_ENGINE,
DEFAULT_OUTPUT_FORMAT,
DEFAULT_REGION,
DEFAULT_SAMPLE_RATES,
DEFAULT_TEXT_TYPE,
DEFAULT_VOICE,
SUPPORTED_ENGINES,
SUPPORTED_OUTPUT_FORMATS,
SUPPORTED_REGIONS,
SUPPORTED_SAMPLE_RATES,
SUPPORTED_SAMPLE_RATES_MAP,
SUPPORTED_TEXT_TYPES,
SUPPORTED_VOICES,
)
CONF_REGION = "region_name" _LOGGER: Final = logging.getLogger(__name__)
CONF_ACCESS_KEY_ID = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY = "aws_secret_access_key"
DEFAULT_REGION = "us-east-1" PLATFORM_SCHEMA: Final = BASE_PLATFORM_SCHEMA.extend(
SUPPORTED_REGIONS = [
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2",
"ca-central-1",
"eu-west-1",
"eu-central-1",
"eu-west-2",
"eu-west-3",
"ap-southeast-1",
"ap-southeast-2",
"ap-northeast-2",
"ap-northeast-1",
"ap-south-1",
"sa-east-1",
]
CONF_ENGINE = "engine"
CONF_VOICE = "voice"
CONF_OUTPUT_FORMAT = "output_format"
CONF_SAMPLE_RATE = "sample_rate"
CONF_TEXT_TYPE = "text_type"
SUPPORTED_VOICES = [
"Olivia", # Female, Australian, Neural
"Zhiyu", # Chinese
"Mads",
"Naja", # Danish
"Ruben",
"Lotte", # Dutch
"Russell",
"Nicole", # English Australian
"Brian",
"Amy",
"Emma", # English
"Aditi",
"Raveena", # English, Indian
"Joey",
"Justin",
"Matthew",
"Ivy",
"Joanna",
"Kendra",
"Kimberly",
"Salli", # English
"Geraint", # English Welsh
"Mathieu",
"Celine",
"Lea", # French
"Chantal", # French Canadian
"Hans",
"Marlene",
"Vicki", # German
"Aditi", # Hindi
"Karl",
"Dora", # Icelandic
"Giorgio",
"Carla",
"Bianca", # Italian
"Takumi",
"Mizuki", # Japanese
"Seoyeon", # Korean
"Liv", # Norwegian
"Jacek",
"Jan",
"Ewa",
"Maja", # Polish
"Ricardo",
"Vitoria", # Portuguese, Brazilian
"Cristiano",
"Ines", # Portuguese, European
"Carmen", # Romanian
"Maxim",
"Tatyana", # Russian
"Enrique",
"Conchita",
"Lucia", # Spanish European
"Mia", # Spanish Mexican
"Miguel", # Spanish US
"Penelope", # Spanish US
"Lupe", # Spanish US
"Astrid", # Swedish
"Filiz", # Turkish
"Gwyneth", # Welsh
]
SUPPORTED_OUTPUT_FORMATS = ["mp3", "ogg_vorbis", "pcm"]
SUPPORTED_ENGINES = ["neural", "standard"]
SUPPORTED_SAMPLE_RATES = ["8000", "16000", "22050", "24000"]
SUPPORTED_SAMPLE_RATES_MAP = {
"mp3": ["8000", "16000", "22050", "24000"],
"ogg_vorbis": ["8000", "16000", "22050"],
"pcm": ["8000", "16000"],
}
SUPPORTED_TEXT_TYPES = ["text", "ssml"]
CONTENT_TYPE_EXTENSIONS = {"audio/mpeg": "mp3", "audio/ogg": "ogg", "audio/pcm": "pcm"}
DEFAULT_ENGINE = "standard"
DEFAULT_VOICE = "Joanna"
DEFAULT_OUTPUT_FORMAT = "mp3"
DEFAULT_TEXT_TYPE = "text"
DEFAULT_SAMPLE_RATES = {"mp3": "22050", "ogg_vorbis": "22050", "pcm": "16000"}
AWS_CONF_CONNECT_TIMEOUT = 10
AWS_CONF_READ_TIMEOUT = 5
AWS_CONF_MAX_POOL_CONNECTIONS = 1
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{ {
vol.Optional(CONF_REGION, default=DEFAULT_REGION): vol.In(SUPPORTED_REGIONS), vol.Optional(CONF_REGION, default=DEFAULT_REGION): vol.In(SUPPORTED_REGIONS),
vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string,
@ -151,11 +69,15 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
def get_engine(hass, config, discovery_info=None): def get_engine(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> Provider | None:
"""Set up Amazon Polly speech component.""" """Set up Amazon Polly speech component."""
output_format = config[CONF_OUTPUT_FORMAT] output_format = config[CONF_OUTPUT_FORMAT]
sample_rate = config.get(CONF_SAMPLE_RATE, DEFAULT_SAMPLE_RATES[output_format]) sample_rate = config.get(CONF_SAMPLE_RATE, DEFAULT_SAMPLE_RATES[output_format])
if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP.get(output_format): if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP[output_format]:
_LOGGER.error( _LOGGER.error(
"%s is not a valid sample rate for %s", sample_rate, output_format "%s is not a valid sample rate for %s", sample_rate, output_format
) )
@ -163,7 +85,7 @@ def get_engine(hass, config, discovery_info=None):
config[CONF_SAMPLE_RATE] = sample_rate config[CONF_SAMPLE_RATE] = sample_rate
profile = config.get(CONF_PROFILE_NAME) profile: str | None = config.get(CONF_PROFILE_NAME)
if profile is not None: if profile is not None:
boto3.setup_default_session(profile_name=profile) boto3.setup_default_session(profile_name=profile)
@ -185,16 +107,20 @@ def get_engine(hass, config, discovery_info=None):
polly_client = boto3.client("polly", **aws_config) polly_client = boto3.client("polly", **aws_config)
supported_languages = [] supported_languages: list[str] = []
all_voices = {} all_voices: dict[str, dict[str, str]] = {}
all_voices_req = polly_client.describe_voices() all_voices_req = polly_client.describe_voices()
for voice in all_voices_req.get("Voices"): for voice in all_voices_req.get("Voices", []):
all_voices[voice.get("Id")] = voice voice_id: str | None = voice.get("Id")
if voice.get("LanguageCode") not in supported_languages: if voice_id is None:
supported_languages.append(voice.get("LanguageCode")) continue
all_voices[voice_id] = voice
language_code: str | None = voice.get("LanguageCode")
if language_code is not None and language_code not in supported_languages:
supported_languages.append(language_code)
return AmazonPollyProvider(polly_client, config, supported_languages, all_voices) return AmazonPollyProvider(polly_client, config, supported_languages, all_voices)
@ -202,39 +128,53 @@ def get_engine(hass, config, discovery_info=None):
class AmazonPollyProvider(Provider): class AmazonPollyProvider(Provider):
"""Amazon Polly speech api provider.""" """Amazon Polly speech api provider."""
def __init__(self, polly_client, config, supported_languages, all_voices): def __init__(
self,
polly_client: boto3.client,
config: ConfigType,
supported_languages: list[str],
all_voices: dict[str, dict[str, str]],
) -> None:
"""Initialize Amazon Polly provider for TTS.""" """Initialize Amazon Polly provider for TTS."""
self.client = polly_client self.client = polly_client
self.config = config self.config = config
self.supported_langs = supported_languages self.supported_langs = supported_languages
self.all_voices = all_voices self.all_voices = all_voices
self.default_voice = self.config[CONF_VOICE] self.default_voice: str = self.config[CONF_VOICE]
self.name = "Amazon Polly" self.name = "Amazon Polly"
@property @property
def supported_languages(self): def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
return self.supported_langs return self.supported_langs
@property @property
def default_language(self): def default_language(self) -> str | None:
"""Return the default language.""" """Return the default language."""
return self.all_voices.get(self.default_voice).get("LanguageCode") return self.all_voices.get(self.default_voice, {}).get("LanguageCode")
@property @property
def default_options(self): def default_options(self) -> dict[str, str]:
"""Return dict include default options.""" """Return dict include default options."""
return {CONF_VOICE: self.default_voice} return {CONF_VOICE: self.default_voice}
@property @property
def supported_options(self): def supported_options(self) -> list[str]:
"""Return a list of supported options.""" """Return a list of supported options."""
return [CONF_VOICE] return [CONF_VOICE]
def get_tts_audio(self, message, language=None, options=None): def get_tts_audio(
self,
message: str,
language: str | None = None,
options: dict[str, str] | None = None,
) -> TtsAudioType:
"""Request TTS file from Polly.""" """Request TTS file from Polly."""
if options is None or language is None:
_LOGGER.debug("language and/or options were missing")
return None, None
voice_id = options.get(CONF_VOICE, self.default_voice) voice_id = options.get(CONF_VOICE, self.default_voice)
voice_in_dict = self.all_voices.get(voice_id) voice_in_dict = self.all_voices[voice_id]
if language != voice_in_dict.get("LanguageCode"): if language != voice_in_dict.get("LanguageCode"):
_LOGGER.error("%s does not support the %s language", voice_id, language) _LOGGER.error("%s does not support the %s language", voice_id, language)
return None, None return None, None

View File

@ -9,7 +9,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
from typing import cast from typing import Optional, Tuple, cast
from aiohttp import web from aiohttp import web
import mutagen import mutagen
@ -47,6 +47,8 @@ from homeassistant.util.yaml import load_yaml
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
TtsAudioType = Tuple[Optional[str], Optional[bytes]]
ATTR_CACHE = "cache" ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language" ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message" ATTR_MESSAGE = "message"

View File

@ -99,6 +99,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.amazon_polly.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.ampio.*] [mypy-homeassistant.components.ampio.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true