Minor adjustment of tts typing (#93450)

This commit is contained in:
Erik Montnemery 2023-05-24 21:02:55 +02:00 committed by GitHub
parent 68379dd55a
commit 30d9d7d905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 49 additions and 62 deletions

View File

@ -167,12 +167,9 @@ class AmazonPollyProvider(Provider):
self, self,
message: str, message: str,
language: str, language: str,
options: dict[str, Any] | None = None, options: dict[str, Any],
) -> TtsAudioType: ) -> TtsAudioType:
"""Request TTS file from Polly.""" """Request TTS file from Polly."""
if options is None or language is None:
_LOGGER.debug("language and/or options were missing")
return None, None
voice_id = options.get(CONF_VOICE, self.default_voice) voice_id = options.get(CONF_VOICE, self.default_voice)
voice_in_dict = self.all_voices[voice_id] voice_in_dict = self.all_voices[voice_id]
if language != voice_in_dict.get("LanguageCode"): if language != voice_in_dict.get("LanguageCode"):

View File

@ -104,7 +104,7 @@ class BaiduTTSProvider(Provider):
"""Return a list of supported options.""" """Return a list of supported options."""
return SUPPORTED_OPTIONS return SUPPORTED_OPTIONS
def get_tts_audio(self, message, language, options=None): def get_tts_audio(self, message, language, options):
"""Load TTS from BaiduTTS.""" """Load TTS from BaiduTTS."""
aip_speech = AipSpeech( aip_speech = AipSpeech(
@ -113,14 +113,11 @@ class BaiduTTSProvider(Provider):
self._app_data["secretkey"], self._app_data["secretkey"],
) )
if options is None: speech_data = self._speech_conf_data.copy()
result = aip_speech.synthesis(message, language, 1, self._speech_conf_data) for key, value in options.items():
else: speech_data[_OPTIONS[key]] = value
speech_data = self._speech_conf_data.copy()
for key, value in options.items():
speech_data[_OPTIONS[key]] = value
result = aip_speech.synthesis(message, language, 1, speech_data) result = aip_speech.synthesis(message, language, 1, speech_data)
if isinstance(result, dict): if isinstance(result, dict):
_LOGGER.error( _LOGGER.error(

View File

@ -134,12 +134,11 @@ class CloudProvider(Provider):
} }
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load TTS from NabuCasa Cloud.""" """Load TTS from NabuCasa Cloud."""
# Process TTS # Process TTS
try: try:
assert options is not None
data = await self.cloud.voice.process_tts( data = await self.cloud.voice.process_tts(
text=message, text=message,
language=language, language=language,

View File

@ -57,7 +57,7 @@ class DemoProvider(Provider):
return ["voice", "age"] return ["voice", "age"]
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load TTS from demo.""" """Load TTS from demo."""
filename = os.path.join(os.path.dirname(__file__), "tts.mp3") filename = os.path.join(os.path.dirname(__file__), "tts.mp3")

View File

@ -241,7 +241,7 @@ class GoogleCloudTTSProvider(Provider):
CONF_TEXT_TYPE: self._text_type, CONF_TEXT_TYPE: self._text_type,
} }
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(self, message, language, options):
"""Load TTS from google.""" """Load TTS from google."""
options_schema = vol.Schema( options_schema = vol.Schema(
{ {

View File

@ -59,13 +59,13 @@ class GoogleProvider(Provider):
"""Return a list of supported options.""" """Return a list of supported options."""
return SUPPORT_OPTIONS return SUPPORT_OPTIONS
def get_tts_audio(self, message, language, options=None): def get_tts_audio(self, message, language, options):
"""Load TTS from google.""" """Load TTS from google."""
tld = self._tld tld = self._tld
if language in MAP_LANG_TLD: if language in MAP_LANG_TLD:
tld = MAP_LANG_TLD[language].tld tld = MAP_LANG_TLD[language].tld
language = MAP_LANG_TLD[language].lang language = MAP_LANG_TLD[language].lang
if options is not None and "tld" in options: if "tld" in options:
tld = options["tld"] tld = options["tld"]
tts = gTTS(text=message, lang=language, tld=tld) tts = gTTS(text=message, lang=language, tld=tld)
mp3_data = BytesIO() mp3_data = BytesIO()

View File

@ -80,7 +80,7 @@ class MaryTTSProvider(Provider):
"""Return a list of supported options.""" """Return a list of supported options."""
return SUPPORT_OPTIONS return SUPPORT_OPTIONS
def get_tts_audio(self, message, language, options=None): def get_tts_audio(self, message, language, options):
"""Load TTS from MaryTTS.""" """Load TTS from MaryTTS."""
effects = options[CONF_EFFECT] effects = options[CONF_EFFECT]

View File

@ -176,7 +176,7 @@ class MicrosoftProvider(Provider):
"""Return a dict include default options.""" """Return a dict include default options."""
return {CONF_GENDER: self._gender, CONF_TYPE: self._type} return {CONF_GENDER: self._gender, CONF_TYPE: self._type}
def get_tts_audio(self, message, language, options=None): def get_tts_audio(self, message, language, options):
"""Load TTS from Microsoft.""" """Load TTS from Microsoft."""
if language is None: if language is None:
language = self._lang language = self._lang

View File

@ -46,7 +46,7 @@ class PicoProvider(Provider):
"""Return list of supported languages.""" """Return list of supported languages."""
return SUPPORT_LANGUAGES return SUPPORT_LANGUAGES
def get_tts_audio(self, message, language, options=None): def get_tts_audio(self, message, language, options):
"""Load TTS using pico2wave.""" """Load TTS using pico2wave."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
fname = tmpf.name fname = tmpf.name

View File

@ -364,7 +364,7 @@ class TextToSpeechEntity(RestoreEntity):
@final @final
async def internal_async_get_tts_audio( async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Process an audio stream to TTS service. """Process an audio stream to TTS service.
@ -377,13 +377,13 @@ class TextToSpeechEntity(RestoreEntity):
) )
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load tts audio file from the engine.""" """Load tts audio file from the engine."""
raise NotImplementedError() raise NotImplementedError()
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load tts audio file from the engine. """Load tts audio file from the engine.
@ -478,9 +478,9 @@ class SpeechManager:
def process_options( def process_options(
self, self,
engine_instance: TextToSpeechEntity | Provider, engine_instance: TextToSpeechEntity | Provider,
language: str | None = None, language: str | None,
options: dict | None = None, options: dict | None,
) -> tuple[str, dict | None]: ) -> tuple[str, dict[str, Any]]:
"""Validate and process options.""" """Validate and process options."""
# Languages # Languages
language = language or engine_instance.default_language language = language or engine_instance.default_language
@ -491,23 +491,18 @@ class SpeechManager:
): ):
raise HomeAssistantError(f"Language '{language}' not supported") raise HomeAssistantError(f"Language '{language}' not supported")
# Options # Update default options with provided options
if (default_options := engine_instance.default_options) and options: merged_options = dict(engine_instance.default_options or {})
merged_options = dict(default_options) merged_options.update(options or {})
merged_options.update(options)
options = merged_options
if not options:
options = None if default_options is None else dict(default_options)
if options is not None: supported_options = engine_instance.supported_options or []
supported_options = engine_instance.supported_options or [] invalid_opts = [
invalid_opts = [ opt_name for opt_name in merged_options if opt_name not in supported_options
opt_name for opt_name in options if opt_name not in supported_options ]
] if invalid_opts:
if invalid_opts: raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
return language, options return language, merged_options
async def async_get_url_path( async def async_get_url_path(
self, self,
@ -602,7 +597,7 @@ class SpeechManager:
message: str, message: str,
cache: bool, cache: bool,
language: str, language: str,
options: dict | None, options: dict[str, Any],
) -> str: ) -> str:
"""Receive TTS, store for view in cache and return filename. """Receive TTS, store for view in cache and return filename.

View File

@ -240,13 +240,13 @@ class Provider:
return None return None
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load tts audio file from provider.""" """Load tts audio file from provider."""
raise NotImplementedError() raise NotImplementedError()
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load tts audio file from provider. """Load tts audio file from provider.

View File

@ -187,7 +187,7 @@ class VoiceRSSProvider(Provider):
"""Return list of supported languages.""" """Return list of supported languages."""
return SUPPORT_LANGUAGES return SUPPORT_LANGUAGES
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(self, message, language, options):
"""Load TTS from VoiceRSS.""" """Load TTS from VoiceRSS."""
websession = async_get_clientsession(self.hass) websession = async_get_clientsession(self.hass)
form_data = self._form_data.copy() form_data = self._form_data.copy()

View File

@ -180,7 +180,7 @@ class WatsonTTSProvider(Provider):
"""Return a list of supported options.""" """Return a list of supported options."""
return [CONF_VOICE] return [CONF_VOICE]
def get_tts_audio(self, message, language=None, options=None): def get_tts_audio(self, message, language, options):
"""Request TTS file from Watson TTS.""" """Request TTS file from Watson TTS."""
response = self.service.synthesize( response = self.service.synthesize(
text=message, accept=self.output_format, voice=options[CONF_VOICE] text=message, accept=self.output_format, voice=options[CONF_VOICE]

View File

@ -94,7 +94,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
"""Return a list of supported voices for a language.""" """Return a list of supported voices for a language."""
return self._voices.get(language) return self._voices.get(language)
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(self, message, language, options):
"""Load TTS from UNIX socket.""" """Load TTS from UNIX socket."""
try: try:
async with AsyncTcpClient(self.service.host, self.service.port) as client: async with AsyncTcpClient(self.service.host, self.service.port) as client:
@ -129,7 +129,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
except (OSError, WyomingError): except (OSError, WyomingError):
return (None, None) return (None, None)
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"): if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
return ("wav", data) return ("wav", data)
# Raw output (convert to 16Khz, 16-bit mono) # Raw output (convert to 16Khz, 16-bit mono)

View File

@ -114,11 +114,10 @@ class YandexSpeechKitProvider(Provider):
"""Return list of supported options.""" """Return list of supported options."""
return SUPPORTED_OPTIONS return SUPPORTED_OPTIONS
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(self, message, language, options):
"""Load TTS from yandex.""" """Load TTS from yandex."""
websession = async_get_clientsession(self.hass) websession = async_get_clientsession(self.hass)
actual_language = language actual_language = language
options = options or {}
try: try:
async with async_timeout.timeout(10): async with async_timeout.timeout(10):

View File

@ -2449,7 +2449,7 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
), ),
TypeHintMatch( TypeHintMatch(
function_name="get_tts_audio", function_name="get_tts_audio",
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"}, arg_types={1: "str", 2: "str", 3: "dict[str, Any]"},
return_type="TtsAudioType", return_type="TtsAudioType",
has_async_counterpart=True, has_async_counterpart=True,
), ),

View File

@ -127,7 +127,7 @@ class MockTTSProvider(tts.Provider):
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT] return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Load TTS data.""" """Load TTS data."""
return ("mp3", b"") return ("mp3", b"")

View File

@ -76,7 +76,7 @@ class BaseProvider:
return ["voice", "age"] return ["voice", "age"]
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:
"""Load TTS dat.""" """Load TTS dat."""
return ("mp3", b"") return ("mp3", b"")

View File

@ -1021,7 +1021,7 @@ class MockProviderBoom(MockProvider):
"""Mock provider that blows up.""" """Mock provider that blows up."""
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Load TTS dat.""" """Load TTS dat."""
# This should not be called, data should be fetched from cache # This should not be called, data should be fetched from cache
@ -1032,7 +1032,7 @@ class MockEntityBoom(MockTTSEntity):
"""Mock entity that blows up.""" """Mock entity that blows up."""
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Load TTS dat.""" """Load TTS dat."""
# This should not be called, data should be fetched from cache # This should not be called, data should be fetched from cache
@ -1116,7 +1116,7 @@ class MockProviderEmpty(MockProvider):
"""Mock provider with empty get_tts_audio.""" """Mock provider with empty get_tts_audio."""
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Load TTS dat.""" """Load TTS dat."""
return (None, None) return (None, None)
@ -1126,7 +1126,7 @@ class MockEntityEmpty(MockTTSEntity):
"""Mock entity with empty get_tts_audio.""" """Mock entity with empty get_tts_audio."""
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
"""Load TTS dat.""" """Load TTS dat."""
return (None, None) return (None, None)
@ -1486,7 +1486,7 @@ async def test_legacy_fetching_in_async(
return {tts.ATTR_AUDIO_OUTPUT: "mp3"} return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
return ("mp3", await tts_audio) return ("mp3", await tts_audio)
@ -1559,7 +1559,7 @@ async def test_fetching_in_async(
return {tts.ATTR_AUDIO_OUTPUT: "mp3"} return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] | None = None self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType: ) -> tts.TtsAudioType:
return ("mp3", await tts_audio) return ("mp3", await tts_audio)

View File

@ -103,7 +103,7 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World" assert message == "Hello World"
assert language == "en_US" assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
# Pass language and options # Pass language and options
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()
@ -138,7 +138,7 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
message, language = mock_get_tts_audio.mock_calls[0][1] message, language = mock_get_tts_audio.mock_calls[0][1]
assert message == "Hello World" assert message == "Hello World"
assert language == "en_US" assert language == "en_US"
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
# Pass language and options # Pass language and options
mock_get_tts_audio.reset_mock() mock_get_tts_audio.reset_mock()