From fdcb88977a82b5578acdcf4e1b57423ce8b06626 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 24 Apr 2025 16:23:15 -0400 Subject: [PATCH] Add voice styles to HA Cloud (#143605) * Add voice styles to HA Cloud * Add seperator and extract util --- homeassistant/components/cloud/const.py | 2 + homeassistant/components/cloud/http_api.py | 45 ++++-- homeassistant/components/cloud/tts.py | 175 +++++++++++++++------ tests/components/cloud/test_http_api.py | 22 ++- 4 files changed, 169 insertions(+), 75 deletions(-) diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index e0c15c74cab..9a977d2a5b9 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -93,3 +93,5 @@ STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text" TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech" LOGIN_MFA_TIMEOUT = 60 + +VOICE_STYLE_SEPERATOR = "||" diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 9226110bca2..7c7cb925e4f 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -57,6 +57,7 @@ from .const import ( PREF_REMOTE_ALLOW_REMOTE_ENABLE, PREF_TTS_DEFAULT_VOICE, REQUEST_TIMEOUT, + VOICE_STYLE_SEPERATOR, ) from .google_config import CLOUD_GOOGLE from .repairs import async_manage_legacy_subscription_issue @@ -591,10 +592,21 @@ async def websocket_subscription( def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]: """Validate language and voice.""" language, voice = value + style: str | None + voice, _, style = voice.partition(VOICE_STYLE_SEPERATOR) + if not style: + style = None if language not in TTS_VOICES: raise vol.Invalid(f"Invalid language {language}") - if voice not in TTS_VOICES[language]: + if voice not in (language_info := TTS_VOICES[language]): raise vol.Invalid(f"Invalid voice {voice} for language {language}") + voice_info = language_info[voice] + if style and ( + isinstance(voice_info, str) or style not in voice_info.get("variants", []) + ): + raise vol.Invalid( + f"Invalid style {style} for voice {voice} in language {language}" + ) return value @@ -1012,13 +1024,24 @@ def tts_info( msg: dict[str, Any], ) -> None: """Fetch available tts info.""" - connection.send_result( - msg["id"], - { - "languages": [ - (language, voice) - for language, voices in TTS_VOICES.items() - for voice in voices - ] - }, - ) + result = [] + for language, voices in TTS_VOICES.items(): + for voice_id, voice_info in voices.items(): + if isinstance(voice_info, str): + result.append((language, voice_id, voice_info)) + continue + + name = voice_info["name"] + result.append((language, voice_id, name)) + result.extend( + [ + ( + language, + f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}", + f"{name} ({variant})", + ) + for variant in voice_info.get("variants", []) + ] + ) + + connection.send_result(msg["id"], {"languages": result}) diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index b5e4dc1cd84..ca3e0719998 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -31,7 +31,13 @@ from homeassistant.setup import async_when_setup from .assist_pipeline import async_migrate_cloud_pipeline_engine from .client import CloudClient -from .const import DATA_CLOUD, DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID +from .const import ( + DATA_CLOUD, + DATA_PLATFORMS_SETUP, + DOMAIN, + TTS_ENTITY_UNIQUE_ID, + VOICE_STYLE_SEPERATOR, +) from .prefs import CloudPreferences ATTR_GENDER = "gender" @@ -195,6 +201,39 @@ DEFAULT_VOICES = { _LOGGER = logging.getLogger(__name__) +@callback +def _prepare_voice_args( + *, + hass: HomeAssistant, + language: str, + voice: str, + gender: str | None, +) -> dict: + """Prepare voice arguments.""" + gender = handle_deprecated_gender(hass, gender) + style: str | None + original_voice, _, style = voice.partition(VOICE_STYLE_SEPERATOR) + if not style: + style = None + updated_voice = handle_deprecated_voice(hass, original_voice) + if updated_voice not in TTS_VOICES[language]: + default_voice = DEFAULT_VOICES[language] + _LOGGER.debug( + "Unsupported voice %s detected, falling back to default %s for %s", + voice, + default_voice, + language, + ) + updated_voice = default_voice + + return { + "language": language, + "voice": updated_voice, + "gender": gender, + "style": style, + } + + def _deprecated_platform(value: str) -> str: """Validate if platform is deprecated.""" if value == DOMAIN: @@ -332,42 +371,59 @@ class CloudTTSEntity(TextToSpeechEntity): """Return a list of supported voices for a language.""" if not (voices := TTS_VOICES.get(language)): return None - return [ - Voice( - voice, - voice_info["name"] if isinstance(voice_info, dict) else voice_info, + + result = [] + + for voice_id, voice_info in voices.items(): + if isinstance(voice_info, str): + result.append( + Voice( + voice_id, + voice_info, + ) + ) + continue + + name = voice_info["name"] + + result.append( + Voice( + voice_id, + name, + ) ) - for voice, voice_info in voices.items() - ] + result.extend( + [ + Voice( + f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}", + f"{name} ({variant})", + ) + for variant in voice_info.get("variants", []) + ] + ) + + return result async def async_get_tts_audio( self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: """Load TTS from Home Assistant Cloud.""" - gender: Gender | str | None = options.get(ATTR_GENDER) - gender = handle_deprecated_gender(self.hass, gender) - original_voice: str = options.get( - ATTR_VOICE, - self._voice if language == self._language else DEFAULT_VOICES[language], - ) - voice = handle_deprecated_voice(self.hass, original_voice) - if voice not in TTS_VOICES[language]: - default_voice = DEFAULT_VOICES[language] - _LOGGER.debug( - "Unsupported voice %s detected, falling back to default %s for %s", - voice, - default_voice, - language, - ) - voice = default_voice # Process TTS try: data = await self.cloud.voice.process_tts( text=message, - language=language, - gender=gender, - voice=voice, output=options[ATTR_AUDIO_OUTPUT], + **_prepare_voice_args( + hass=self.hass, + language=language, + voice=options.get( + ATTR_VOICE, + self._voice + if language == self._language + else DEFAULT_VOICES[language], + ), + gender=options.get(ATTR_GENDER), + ), ) except VoiceError as err: _LOGGER.error("Voice error: %s", err) @@ -411,13 +467,38 @@ class CloudProvider(Provider): """Return a list of supported voices for a language.""" if not (voices := TTS_VOICES.get(language)): return None - return [ - Voice( - voice, - voice_info["name"] if isinstance(voice_info, dict) else voice_info, + + result = [] + + for voice_id, voice_info in voices.items(): + if isinstance(voice_info, str): + result.append( + Voice( + voice_id, + voice_info, + ) + ) + continue + + name = voice_info["name"] + + result.append( + Voice( + voice_id, + name, + ) ) - for voice, voice_info in voices.items() - ] + result.extend( + [ + Voice( + f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}", + f"{name} ({variant})", + ) + for variant in voice_info.get("variants", []) + ] + ) + + return result @property def default_options(self) -> dict[str, str]: @@ -431,30 +512,22 @@ class CloudProvider(Provider): ) -> TtsAudioType: """Load TTS from Home Assistant Cloud.""" assert self.hass is not None - gender: Gender | str | None = options.get(ATTR_GENDER) - gender = handle_deprecated_gender(self.hass, gender) - original_voice: str = options.get( - ATTR_VOICE, - self._voice if language == self._language else DEFAULT_VOICES[language], - ) - voice = handle_deprecated_voice(self.hass, original_voice) - if voice not in TTS_VOICES[language]: - default_voice = DEFAULT_VOICES[language] - _LOGGER.debug( - "Unsupported voice %s detected, falling back to default %s for %s", - voice, - default_voice, - language, - ) - voice = default_voice # Process TTS try: data = await self.cloud.voice.process_tts( text=message, - language=language, - gender=gender, - voice=voice, output=options[ATTR_AUDIO_OUTPUT], + **_prepare_voice_args( + hass=self.hass, + language=language, + voice=options.get( + ATTR_VOICE, + self._voice + if language == self._language + else DEFAULT_VOICES[language], + ), + gender=options.get(ATTR_GENDER), + ), ) except VoiceError as err: _LOGGER.error("Voice error: %s", err) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 73ec1aceb55..2722445445e 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -4,7 +4,6 @@ from collections.abc import Callable, Coroutine from copy import deepcopy import datetime from http import HTTPStatus -import json import logging from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch @@ -20,7 +19,6 @@ from hass_nabucasa.auth import ( ) from hass_nabucasa.const import STATE_CONNECTED from hass_nabucasa.remote import CertificateStatus -from hass_nabucasa.voice_data import TTS_VOICES import pytest from syrupy.assertion import SnapshotAssertion @@ -31,6 +29,7 @@ from homeassistant.components.alexa import errors as alexa_errors from homeassistant.components.alexa.entities import LightCapabilities from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN +from homeassistant.components.cloud.http_api import validate_language_voice from homeassistant.components.google_assistant.helpers import GoogleEntity from homeassistant.components.homeassistant import exposed_entities from homeassistant.components.websocket_api import ERR_INVALID_FORMAT @@ -1822,17 +1821,14 @@ async def test_tts_info( response = await client.receive_json() assert response["success"] - assert response["result"] == { - "languages": json.loads( - json.dumps( - [ - (language, voice) - for language, voices in TTS_VOICES.items() - for voice in voices - ] - ) - ) - } + assert "languages" in response["result"] + assert all(len(lang) for lang in response["result"]["languages"]) + assert len(response["result"]["languages"]) > 300 + assert ( + len([lang for lang in response["result"]["languages"] if "||" in lang[1]]) > 100 + ) + for lang in response["result"]["languages"]: + assert validate_language_voice(lang[:2]) @pytest.mark.parametrize(