From 822660732b0b47baf264708eb582e7e7e22133e0 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Tue, 3 Sep 2024 15:45:37 +0200 Subject: [PATCH] Support setting Amazon Polly engine in service call (#120226) --- homeassistant/components/amazon_polly/tts.py | 23 ++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/amazon_polly/tts.py b/homeassistant/components/amazon_polly/tts.py index 1fc972fa3a1..62852848a9c 100644 --- a/homeassistant/components/amazon_polly/tts.py +++ b/homeassistant/components/amazon_polly/tts.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections import defaultdict import logging from typing import Any, Final @@ -114,6 +115,8 @@ def get_engine( all_voices: dict[str, dict[str, str]] = {} + all_engines: dict[str, set[str]] = defaultdict(set) + all_voices_req = polly_client.describe_voices() for voice in all_voices_req.get("Voices", []): @@ -124,8 +127,12 @@ def get_engine( language_code: str | None = voice.get("LanguageCode") if language_code is not None and language_code not in supported_languages: supported_languages.append(language_code) + for engine in voice.get("SupportedEngines"): + all_engines[engine].add(voice_id) - return AmazonPollyProvider(polly_client, config, supported_languages, all_voices) + return AmazonPollyProvider( + polly_client, config, supported_languages, all_voices, all_engines + ) class AmazonPollyProvider(Provider): @@ -137,13 +144,16 @@ class AmazonPollyProvider(Provider): config: ConfigType, supported_languages: list[str], all_voices: dict[str, dict[str, str]], + all_engines: dict[str, set[str]], ) -> None: """Initialize Amazon Polly provider for TTS.""" self.client = polly_client self.config = config self.supported_langs = supported_languages self.all_voices = all_voices + self.all_engines = all_engines self.default_voice: str = self.config[CONF_VOICE] + self.default_engine: str = self.config[CONF_ENGINE] self.name = "Amazon Polly" @property @@ -159,12 +169,12 @@ class AmazonPollyProvider(Provider): @property def default_options(self) -> dict[str, str]: """Return dict include default options.""" - return {CONF_VOICE: self.default_voice} + return {CONF_VOICE: self.default_voice, CONF_ENGINE: self.default_engine} @property def supported_options(self) -> list[str]: """Return a list of supported options.""" - return [CONF_VOICE] + return [CONF_VOICE, CONF_ENGINE] def get_tts_audio( self, @@ -179,9 +189,14 @@ class AmazonPollyProvider(Provider): _LOGGER.error("%s does not support the %s language", voice_id, language) return None, None + engine = options.get(CONF_ENGINE, self.default_engine) + if voice_id not in self.all_engines[engine]: + _LOGGER.error("%s does not support the %s engine", voice_id, engine) + return None, None + _LOGGER.debug("Requesting TTS file for text: %s", message) resp = self.client.synthesize_speech( - Engine=self.config[CONF_ENGINE], + Engine=engine, OutputFormat=self.config[CONF_OUTPUT_FORMAT], SampleRate=self.config[CONF_SAMPLE_RATE], Text=message,