Add voice styles to HA Cloud (#143605)

* Add voice styles to HA Cloud

* Add seperator and extract util
This commit is contained in:
Paulus Schoutsen 2025-04-24 16:23:15 -04:00 committed by GitHub
parent a584ccb8f7
commit fdcb88977a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 169 additions and 75 deletions

View File

@ -93,3 +93,5 @@ STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
LOGIN_MFA_TIMEOUT = 60
VOICE_STYLE_SEPERATOR = "||"

View File

@ -57,6 +57,7 @@ from .const import (
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
PREF_TTS_DEFAULT_VOICE,
REQUEST_TIMEOUT,
VOICE_STYLE_SEPERATOR,
)
from .google_config import CLOUD_GOOGLE
from .repairs import async_manage_legacy_subscription_issue
@ -591,10 +592,21 @@ async def websocket_subscription(
def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
"""Validate language and voice."""
language, voice = value
style: str | None
voice, _, style = voice.partition(VOICE_STYLE_SEPERATOR)
if not style:
style = None
if language not in TTS_VOICES:
raise vol.Invalid(f"Invalid language {language}")
if voice not in TTS_VOICES[language]:
if voice not in (language_info := TTS_VOICES[language]):
raise vol.Invalid(f"Invalid voice {voice} for language {language}")
voice_info = language_info[voice]
if style and (
isinstance(voice_info, str) or style not in voice_info.get("variants", [])
):
raise vol.Invalid(
f"Invalid style {style} for voice {voice} in language {language}"
)
return value
@ -1012,13 +1024,24 @@ def tts_info(
msg: dict[str, Any],
) -> None:
"""Fetch available tts info."""
connection.send_result(
msg["id"],
{
"languages": [
(language, voice)
for language, voices in TTS_VOICES.items()
for voice in voices
]
},
)
result = []
for language, voices in TTS_VOICES.items():
for voice_id, voice_info in voices.items():
if isinstance(voice_info, str):
result.append((language, voice_id, voice_info))
continue
name = voice_info["name"]
result.append((language, voice_id, name))
result.extend(
[
(
language,
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
f"{name} ({variant})",
)
for variant in voice_info.get("variants", [])
]
)
connection.send_result(msg["id"], {"languages": result})

View File

@ -31,7 +31,13 @@ from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DATA_CLOUD, DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
from .const import (
DATA_CLOUD,
DATA_PLATFORMS_SETUP,
DOMAIN,
TTS_ENTITY_UNIQUE_ID,
VOICE_STYLE_SEPERATOR,
)
from .prefs import CloudPreferences
ATTR_GENDER = "gender"
@ -195,6 +201,39 @@ DEFAULT_VOICES = {
_LOGGER = logging.getLogger(__name__)
@callback
def _prepare_voice_args(
*,
hass: HomeAssistant,
language: str,
voice: str,
gender: str | None,
) -> dict:
"""Prepare voice arguments."""
gender = handle_deprecated_gender(hass, gender)
style: str | None
original_voice, _, style = voice.partition(VOICE_STYLE_SEPERATOR)
if not style:
style = None
updated_voice = handle_deprecated_voice(hass, original_voice)
if updated_voice not in TTS_VOICES[language]:
default_voice = DEFAULT_VOICES[language]
_LOGGER.debug(
"Unsupported voice %s detected, falling back to default %s for %s",
voice,
default_voice,
language,
)
updated_voice = default_voice
return {
"language": language,
"voice": updated_voice,
"gender": gender,
"style": style,
}
def _deprecated_platform(value: str) -> str:
"""Validate if platform is deprecated."""
if value == DOMAIN:
@ -332,42 +371,59 @@ class CloudTTSEntity(TextToSpeechEntity):
"""Return a list of supported voices for a language."""
if not (voices := TTS_VOICES.get(language)):
return None
return [
Voice(
voice,
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
result = []
for voice_id, voice_info in voices.items():
if isinstance(voice_info, str):
result.append(
Voice(
voice_id,
voice_info,
)
)
continue
name = voice_info["name"]
result.append(
Voice(
voice_id,
name,
)
)
for voice, voice_info in voices.items()
]
result.extend(
[
Voice(
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
f"{name} ({variant})",
)
for variant in voice_info.get("variants", [])
]
)
return result
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from Home Assistant Cloud."""
gender: Gender | str | None = options.get(ATTR_GENDER)
gender = handle_deprecated_gender(self.hass, gender)
original_voice: str = options.get(
ATTR_VOICE,
self._voice if language == self._language else DEFAULT_VOICES[language],
)
voice = handle_deprecated_voice(self.hass, original_voice)
if voice not in TTS_VOICES[language]:
default_voice = DEFAULT_VOICES[language]
_LOGGER.debug(
"Unsupported voice %s detected, falling back to default %s for %s",
voice,
default_voice,
language,
)
voice = default_voice
# Process TTS
try:
data = await self.cloud.voice.process_tts(
text=message,
language=language,
gender=gender,
voice=voice,
output=options[ATTR_AUDIO_OUTPUT],
**_prepare_voice_args(
hass=self.hass,
language=language,
voice=options.get(
ATTR_VOICE,
self._voice
if language == self._language
else DEFAULT_VOICES[language],
),
gender=options.get(ATTR_GENDER),
),
)
except VoiceError as err:
_LOGGER.error("Voice error: %s", err)
@ -411,13 +467,38 @@ class CloudProvider(Provider):
"""Return a list of supported voices for a language."""
if not (voices := TTS_VOICES.get(language)):
return None
return [
Voice(
voice,
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
result = []
for voice_id, voice_info in voices.items():
if isinstance(voice_info, str):
result.append(
Voice(
voice_id,
voice_info,
)
)
continue
name = voice_info["name"]
result.append(
Voice(
voice_id,
name,
)
)
for voice, voice_info in voices.items()
]
result.extend(
[
Voice(
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
f"{name} ({variant})",
)
for variant in voice_info.get("variants", [])
]
)
return result
@property
def default_options(self) -> dict[str, str]:
@ -431,30 +512,22 @@ class CloudProvider(Provider):
) -> TtsAudioType:
"""Load TTS from Home Assistant Cloud."""
assert self.hass is not None
gender: Gender | str | None = options.get(ATTR_GENDER)
gender = handle_deprecated_gender(self.hass, gender)
original_voice: str = options.get(
ATTR_VOICE,
self._voice if language == self._language else DEFAULT_VOICES[language],
)
voice = handle_deprecated_voice(self.hass, original_voice)
if voice not in TTS_VOICES[language]:
default_voice = DEFAULT_VOICES[language]
_LOGGER.debug(
"Unsupported voice %s detected, falling back to default %s for %s",
voice,
default_voice,
language,
)
voice = default_voice
# Process TTS
try:
data = await self.cloud.voice.process_tts(
text=message,
language=language,
gender=gender,
voice=voice,
output=options[ATTR_AUDIO_OUTPUT],
**_prepare_voice_args(
hass=self.hass,
language=language,
voice=options.get(
ATTR_VOICE,
self._voice
if language == self._language
else DEFAULT_VOICES[language],
),
gender=options.get(ATTR_GENDER),
),
)
except VoiceError as err:
_LOGGER.error("Voice error: %s", err)

View File

@ -4,7 +4,6 @@ from collections.abc import Callable, Coroutine
from copy import deepcopy
import datetime
from http import HTTPStatus
import json
import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
@ -20,7 +19,6 @@ from hass_nabucasa.auth import (
)
from hass_nabucasa.const import STATE_CONNECTED
from hass_nabucasa.remote import CertificateStatus
from hass_nabucasa.voice_data import TTS_VOICES
import pytest
from syrupy.assertion import SnapshotAssertion
@ -31,6 +29,7 @@ from homeassistant.components.alexa import errors as alexa_errors
from homeassistant.components.alexa.entities import LightCapabilities
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
from homeassistant.components.cloud.http_api import validate_language_voice
from homeassistant.components.google_assistant.helpers import GoogleEntity
from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
@ -1822,17 +1821,14 @@ async def test_tts_info(
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"languages": json.loads(
json.dumps(
[
(language, voice)
for language, voices in TTS_VOICES.items()
for voice in voices
]
)
)
}
assert "languages" in response["result"]
assert all(len(lang) for lang in response["result"]["languages"])
assert len(response["result"]["languages"]) > 300
assert (
len([lang for lang in response["result"]["languages"] if "||" in lang[1]]) > 100
)
for lang in response["result"]["languages"]:
assert validate_language_voice(lang[:2])
@pytest.mark.parametrize(