mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Minor adjustment of tts typing (#93450)
This commit is contained in:
parent
68379dd55a
commit
30d9d7d905
@ -167,12 +167,9 @@ class AmazonPollyProvider(Provider):
|
||||
self,
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict[str, Any] | None = None,
|
||||
options: dict[str, Any],
|
||||
) -> TtsAudioType:
|
||||
"""Request TTS file from Polly."""
|
||||
if options is None or language is None:
|
||||
_LOGGER.debug("language and/or options were missing")
|
||||
return None, None
|
||||
voice_id = options.get(CONF_VOICE, self.default_voice)
|
||||
voice_in_dict = self.all_voices[voice_id]
|
||||
if language != voice_in_dict.get("LanguageCode"):
|
||||
|
@ -104,7 +104,7 @@ class BaiduTTSProvider(Provider):
|
||||
"""Return a list of supported options."""
|
||||
return SUPPORTED_OPTIONS
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from BaiduTTS."""
|
||||
|
||||
aip_speech = AipSpeech(
|
||||
@ -113,14 +113,11 @@ class BaiduTTSProvider(Provider):
|
||||
self._app_data["secretkey"],
|
||||
)
|
||||
|
||||
if options is None:
|
||||
result = aip_speech.synthesis(message, language, 1, self._speech_conf_data)
|
||||
else:
|
||||
speech_data = self._speech_conf_data.copy()
|
||||
for key, value in options.items():
|
||||
speech_data[_OPTIONS[key]] = value
|
||||
speech_data = self._speech_conf_data.copy()
|
||||
for key, value in options.items():
|
||||
speech_data[_OPTIONS[key]] = value
|
||||
|
||||
result = aip_speech.synthesis(message, language, 1, speech_data)
|
||||
result = aip_speech.synthesis(message, language, 1, speech_data)
|
||||
|
||||
if isinstance(result, dict):
|
||||
_LOGGER.error(
|
||||
|
@ -134,12 +134,11 @@ class CloudProvider(Provider):
|
||||
}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from NabuCasa Cloud."""
|
||||
# Process TTS
|
||||
try:
|
||||
assert options is not None
|
||||
data = await self.cloud.voice.process_tts(
|
||||
text=message,
|
||||
language=language,
|
||||
|
@ -57,7 +57,7 @@ class DemoProvider(Provider):
|
||||
return ["voice", "age"]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from demo."""
|
||||
filename = os.path.join(os.path.dirname(__file__), "tts.mp3")
|
||||
|
@ -241,7 +241,7 @@ class GoogleCloudTTSProvider(Provider):
|
||||
CONF_TEXT_TYPE: self._text_type,
|
||||
}
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
async def async_get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from google."""
|
||||
options_schema = vol.Schema(
|
||||
{
|
||||
|
@ -59,13 +59,13 @@ class GoogleProvider(Provider):
|
||||
"""Return a list of supported options."""
|
||||
return SUPPORT_OPTIONS
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from google."""
|
||||
tld = self._tld
|
||||
if language in MAP_LANG_TLD:
|
||||
tld = MAP_LANG_TLD[language].tld
|
||||
language = MAP_LANG_TLD[language].lang
|
||||
if options is not None and "tld" in options:
|
||||
if "tld" in options:
|
||||
tld = options["tld"]
|
||||
tts = gTTS(text=message, lang=language, tld=tld)
|
||||
mp3_data = BytesIO()
|
||||
|
@ -80,7 +80,7 @@ class MaryTTSProvider(Provider):
|
||||
"""Return a list of supported options."""
|
||||
return SUPPORT_OPTIONS
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from MaryTTS."""
|
||||
effects = options[CONF_EFFECT]
|
||||
|
||||
|
@ -176,7 +176,7 @@ class MicrosoftProvider(Provider):
|
||||
"""Return a dict include default options."""
|
||||
return {CONF_GENDER: self._gender, CONF_TYPE: self._type}
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from Microsoft."""
|
||||
if language is None:
|
||||
language = self._lang
|
||||
|
@ -46,7 +46,7 @@ class PicoProvider(Provider):
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Load TTS using pico2wave."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
|
||||
fname = tmpf.name
|
||||
|
@ -364,7 +364,7 @@ class TextToSpeechEntity(RestoreEntity):
|
||||
|
||||
@final
|
||||
async def internal_async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Process an audio stream to TTS service.
|
||||
|
||||
@ -377,13 +377,13 @@ class TextToSpeechEntity(RestoreEntity):
|
||||
)
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine.
|
||||
|
||||
@ -478,9 +478,9 @@ class SpeechManager:
|
||||
def process_options(
|
||||
self,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> tuple[str, dict | None]:
|
||||
language: str | None,
|
||||
options: dict | None,
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Validate and process options."""
|
||||
# Languages
|
||||
language = language or engine_instance.default_language
|
||||
@ -491,23 +491,18 @@ class SpeechManager:
|
||||
):
|
||||
raise HomeAssistantError(f"Language '{language}' not supported")
|
||||
|
||||
# Options
|
||||
if (default_options := engine_instance.default_options) and options:
|
||||
merged_options = dict(default_options)
|
||||
merged_options.update(options)
|
||||
options = merged_options
|
||||
if not options:
|
||||
options = None if default_options is None else dict(default_options)
|
||||
# Update default options with provided options
|
||||
merged_options = dict(engine_instance.default_options or {})
|
||||
merged_options.update(options or {})
|
||||
|
||||
if options is not None:
|
||||
supported_options = engine_instance.supported_options or []
|
||||
invalid_opts = [
|
||||
opt_name for opt_name in options if opt_name not in supported_options
|
||||
]
|
||||
if invalid_opts:
|
||||
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
|
||||
supported_options = engine_instance.supported_options or []
|
||||
invalid_opts = [
|
||||
opt_name for opt_name in merged_options if opt_name not in supported_options
|
||||
]
|
||||
if invalid_opts:
|
||||
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
|
||||
|
||||
return language, options
|
||||
return language, merged_options
|
||||
|
||||
async def async_get_url_path(
|
||||
self,
|
||||
@ -602,7 +597,7 @@ class SpeechManager:
|
||||
message: str,
|
||||
cache: bool,
|
||||
language: str,
|
||||
options: dict | None,
|
||||
options: dict[str, Any],
|
||||
) -> str:
|
||||
"""Receive TTS, store for view in cache and return filename.
|
||||
|
||||
|
@ -240,13 +240,13 @@ class Provider:
|
||||
return None
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
|
@ -187,7 +187,7 @@ class VoiceRSSProvider(Provider):
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
async def async_get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from VoiceRSS."""
|
||||
websession = async_get_clientsession(self.hass)
|
||||
form_data = self._form_data.copy()
|
||||
|
@ -180,7 +180,7 @@ class WatsonTTSProvider(Provider):
|
||||
"""Return a list of supported options."""
|
||||
return [CONF_VOICE]
|
||||
|
||||
def get_tts_audio(self, message, language=None, options=None):
|
||||
def get_tts_audio(self, message, language, options):
|
||||
"""Request TTS file from Watson TTS."""
|
||||
response = self.service.synthesize(
|
||||
text=message, accept=self.output_format, voice=options[CONF_VOICE]
|
||||
|
@ -94,7 +94,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._voices.get(language)
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
async def async_get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from UNIX socket."""
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
@ -129,7 +129,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
except (OSError, WyomingError):
|
||||
return (None, None)
|
||||
|
||||
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
|
||||
if options[tts.ATTR_AUDIO_OUTPUT] == "wav":
|
||||
return ("wav", data)
|
||||
|
||||
# Raw output (convert to 16Khz, 16-bit mono)
|
||||
|
@ -114,11 +114,10 @@ class YandexSpeechKitProvider(Provider):
|
||||
"""Return list of supported options."""
|
||||
return SUPPORTED_OPTIONS
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
async def async_get_tts_audio(self, message, language, options):
|
||||
"""Load TTS from yandex."""
|
||||
websession = async_get_clientsession(self.hass)
|
||||
actual_language = language
|
||||
options = options or {}
|
||||
|
||||
try:
|
||||
async with async_timeout.timeout(10):
|
||||
|
@ -2449,7 +2449,7 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="get_tts_audio",
|
||||
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"},
|
||||
arg_types={1: "str", 2: "str", 3: "dict[str, Any]"},
|
||||
return_type="TtsAudioType",
|
||||
has_async_counterpart=True,
|
||||
),
|
||||
|
@ -127,7 +127,7 @@ class MockTTSProvider(tts.Provider):
|
||||
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS data."""
|
||||
return ("mp3", b"")
|
||||
|
@ -76,7 +76,7 @@ class BaseProvider:
|
||||
return ["voice", "age"]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return ("mp3", b"")
|
||||
|
@ -1021,7 +1021,7 @@ class MockProviderBoom(MockProvider):
|
||||
"""Mock provider that blows up."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
# This should not be called, data should be fetched from cache
|
||||
@ -1032,7 +1032,7 @@ class MockEntityBoom(MockTTSEntity):
|
||||
"""Mock entity that blows up."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
# This should not be called, data should be fetched from cache
|
||||
@ -1116,7 +1116,7 @@ class MockProviderEmpty(MockProvider):
|
||||
"""Mock provider with empty get_tts_audio."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return (None, None)
|
||||
@ -1126,7 +1126,7 @@ class MockEntityEmpty(MockTTSEntity):
|
||||
"""Mock entity with empty get_tts_audio."""
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS dat."""
|
||||
return (None, None)
|
||||
@ -1486,7 +1486,7 @@ async def test_legacy_fetching_in_async(
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
return ("mp3", await tts_audio)
|
||||
|
||||
@ -1559,7 +1559,7 @@ async def test_fetching_in_async(
|
||||
return {tts.ATTR_AUDIO_OUTPUT: "mp3"}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
return ("mp3", await tts_audio)
|
||||
|
||||
|
@ -103,7 +103,7 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Hello World"
|
||||
assert language == "en_US"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
|
||||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
@ -138,7 +138,7 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
|
||||
message, language = mock_get_tts_audio.mock_calls[0][1]
|
||||
assert message == "Hello World"
|
||||
assert language == "en_US"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] is None
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {}
|
||||
|
||||
# Pass language and options
|
||||
mock_get_tts_audio.reset_mock()
|
||||
|
Loading…
x
Reference in New Issue
Block a user