mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 04:37:06 +00:00
Add voice styles to HA Cloud (#143605)
* Add voice styles to HA Cloud * Add separator and extract util
This commit is contained in:
parent
a584ccb8f7
commit
fdcb88977a
@ -93,3 +93,5 @@ STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||
|
||||
LOGIN_MFA_TIMEOUT = 60
|
||||
|
||||
# Token placed between a voice id and an optional style when both are carried
# in a single string ("<voice_id>||<style>"); consumers split on it with
# str.partition. The spelling of the name is kept as-is because other modules
# import it under this identifier.
VOICE_STYLE_SEPERATOR = "||"
|
||||
|
@ -57,6 +57,7 @@ from .const import (
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
REQUEST_TIMEOUT,
|
||||
VOICE_STYLE_SEPERATOR,
|
||||
)
|
||||
from .google_config import CLOUD_GOOGLE
|
||||
from .repairs import async_manage_legacy_subscription_issue
|
||||
@ -591,10 +592,21 @@ async def websocket_subscription(
|
||||
def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
    """Validate a ``(language, voice)`` pair against ``TTS_VOICES``.

    The voice element may be a bare voice id or a voice id followed by
    ``VOICE_STYLE_SEPERATOR`` and a style name.

    Returns ``value`` unchanged when it is valid.

    Raises ``vol.Invalid`` when the language is not a key of ``TTS_VOICES``,
    when the voice id is not listed for that language, or when a style is
    given but the voice entry is a plain string or does not list that style
    under ``"variants"``.
    """
    language, voice = value
    # partition() yields an empty style when no separator is present, which
    # the truthiness check below treats as "no style requested".
    voice, _, style = voice.partition(VOICE_STYLE_SEPERATOR)

    if language not in TTS_VOICES:
        raise vol.Invalid(f"Invalid language {language}")

    if voice not in (language_info := TTS_VOICES[language]):
        raise vol.Invalid(f"Invalid voice {voice} for language {language}")

    voice_info = language_info[voice]
    # A voice entry that is a plain string carries no variants at all.
    if style and (
        isinstance(voice_info, str) or style not in voice_info.get("variants", [])
    ):
        raise vol.Invalid(
            f"Invalid style {style} for voice {voice} in language {language}"
        )

    return value
|
||||
|
||||
|
||||
@ -1012,13 +1024,24 @@ def tts_info(
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Fetch available tts info."""
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"languages": [
|
||||
(language, voice)
|
||||
for language, voices in TTS_VOICES.items()
|
||||
for voice in voices
|
||||
]
|
||||
},
|
||||
)
|
||||
result = []
|
||||
for language, voices in TTS_VOICES.items():
|
||||
for voice_id, voice_info in voices.items():
|
||||
if isinstance(voice_info, str):
|
||||
result.append((language, voice_id, voice_info))
|
||||
continue
|
||||
|
||||
name = voice_info["name"]
|
||||
result.append((language, voice_id, name))
|
||||
result.extend(
|
||||
[
|
||||
(
|
||||
language,
|
||||
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||
f"{name} ({variant})",
|
||||
)
|
||||
for variant in voice_info.get("variants", [])
|
||||
]
|
||||
)
|
||||
|
||||
connection.send_result(msg["id"], {"languages": result})
|
||||
|
@ -31,7 +31,13 @@ from homeassistant.setup import async_when_setup
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||
from .client import CloudClient
|
||||
from .const import DATA_CLOUD, DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
||||
from .const import (
|
||||
DATA_CLOUD,
|
||||
DATA_PLATFORMS_SETUP,
|
||||
DOMAIN,
|
||||
TTS_ENTITY_UNIQUE_ID,
|
||||
VOICE_STYLE_SEPERATOR,
|
||||
)
|
||||
from .prefs import CloudPreferences
|
||||
|
||||
ATTR_GENDER = "gender"
|
||||
@ -195,6 +201,39 @@ DEFAULT_VOICES = {
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
def _prepare_voice_args(
    *,
    hass: HomeAssistant,
    language: str,
    voice: str,
    gender: str | None,
) -> dict[str, Any]:
    """Build the language/voice/gender/style arguments for a TTS request.

    ``voice`` may be a bare voice id or a voice id followed by
    ``VOICE_STYLE_SEPERATOR`` and a style name. The voice id is passed through
    ``handle_deprecated_voice`` and ``gender`` through
    ``handle_deprecated_gender``.

    When the resulting voice is not listed in ``TTS_VOICES[language]``, the
    entry from ``DEFAULT_VOICES[language]`` is used instead and the requested
    style is dropped, because it was selected for a different voice.

    ``language`` is used to index ``TTS_VOICES`` and ``DEFAULT_VOICES``
    directly, so it must be a key of both.

    Returns a dict with the keys ``language``, ``voice``, ``gender`` and
    ``style`` (``None`` when no style applies).
    """
    gender = handle_deprecated_gender(hass, gender)

    requested_voice, _, requested_style = voice.partition(VOICE_STYLE_SEPERATOR)
    # partition() returns "" when there is no separator; report that as None.
    style: str | None = requested_style or None

    updated_voice = handle_deprecated_voice(hass, requested_voice)
    if updated_voice not in TTS_VOICES[language]:
        default_voice = DEFAULT_VOICES[language]
        _LOGGER.debug(
            "Unsupported voice %s detected, falling back to default %s for %s",
            voice,
            default_voice,
            language,
        )
        updated_voice = default_voice
        # The style belonged to the rejected voice; do not carry it over to
        # the fallback voice.
        style = None

    return {
        "language": language,
        "voice": updated_voice,
        "gender": gender,
        "style": style,
    }
|
||||
|
||||
|
||||
def _deprecated_platform(value: str) -> str:
|
||||
"""Validate if platform is deprecated."""
|
||||
if value == DOMAIN:
|
||||
@ -332,42 +371,59 @@ class CloudTTSEntity(TextToSpeechEntity):
|
||||
"""Return a list of supported voices for a language."""
|
||||
if not (voices := TTS_VOICES.get(language)):
|
||||
return None
|
||||
return [
|
||||
Voice(
|
||||
voice,
|
||||
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
|
||||
|
||||
result = []
|
||||
|
||||
for voice_id, voice_info in voices.items():
|
||||
if isinstance(voice_info, str):
|
||||
result.append(
|
||||
Voice(
|
||||
voice_id,
|
||||
voice_info,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
name = voice_info["name"]
|
||||
|
||||
result.append(
|
||||
Voice(
|
||||
voice_id,
|
||||
name,
|
||||
)
|
||||
)
|
||||
for voice, voice_info in voices.items()
|
||||
]
|
||||
result.extend(
|
||||
[
|
||||
Voice(
|
||||
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||
f"{name} ({variant})",
|
||||
)
|
||||
for variant in voice_info.get("variants", [])
|
||||
]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from Home Assistant Cloud."""
|
||||
gender: Gender | str | None = options.get(ATTR_GENDER)
|
||||
gender = handle_deprecated_gender(self.hass, gender)
|
||||
original_voice: str = options.get(
|
||||
ATTR_VOICE,
|
||||
self._voice if language == self._language else DEFAULT_VOICES[language],
|
||||
)
|
||||
voice = handle_deprecated_voice(self.hass, original_voice)
|
||||
if voice not in TTS_VOICES[language]:
|
||||
default_voice = DEFAULT_VOICES[language]
|
||||
_LOGGER.debug(
|
||||
"Unsupported voice %s detected, falling back to default %s for %s",
|
||||
voice,
|
||||
default_voice,
|
||||
language,
|
||||
)
|
||||
voice = default_voice
|
||||
# Process TTS
|
||||
try:
|
||||
data = await self.cloud.voice.process_tts(
|
||||
text=message,
|
||||
language=language,
|
||||
gender=gender,
|
||||
voice=voice,
|
||||
output=options[ATTR_AUDIO_OUTPUT],
|
||||
**_prepare_voice_args(
|
||||
hass=self.hass,
|
||||
language=language,
|
||||
voice=options.get(
|
||||
ATTR_VOICE,
|
||||
self._voice
|
||||
if language == self._language
|
||||
else DEFAULT_VOICES[language],
|
||||
),
|
||||
gender=options.get(ATTR_GENDER),
|
||||
),
|
||||
)
|
||||
except VoiceError as err:
|
||||
_LOGGER.error("Voice error: %s", err)
|
||||
@ -411,13 +467,38 @@ class CloudProvider(Provider):
|
||||
"""Return a list of supported voices for a language."""
|
||||
if not (voices := TTS_VOICES.get(language)):
|
||||
return None
|
||||
return [
|
||||
Voice(
|
||||
voice,
|
||||
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
|
||||
|
||||
result = []
|
||||
|
||||
for voice_id, voice_info in voices.items():
|
||||
if isinstance(voice_info, str):
|
||||
result.append(
|
||||
Voice(
|
||||
voice_id,
|
||||
voice_info,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
name = voice_info["name"]
|
||||
|
||||
result.append(
|
||||
Voice(
|
||||
voice_id,
|
||||
name,
|
||||
)
|
||||
)
|
||||
for voice, voice_info in voices.items()
|
||||
]
|
||||
result.extend(
|
||||
[
|
||||
Voice(
|
||||
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||
f"{name} ({variant})",
|
||||
)
|
||||
for variant in voice_info.get("variants", [])
|
||||
]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def default_options(self) -> dict[str, str]:
|
||||
@ -431,30 +512,22 @@ class CloudProvider(Provider):
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from Home Assistant Cloud."""
|
||||
assert self.hass is not None
|
||||
gender: Gender | str | None = options.get(ATTR_GENDER)
|
||||
gender = handle_deprecated_gender(self.hass, gender)
|
||||
original_voice: str = options.get(
|
||||
ATTR_VOICE,
|
||||
self._voice if language == self._language else DEFAULT_VOICES[language],
|
||||
)
|
||||
voice = handle_deprecated_voice(self.hass, original_voice)
|
||||
if voice not in TTS_VOICES[language]:
|
||||
default_voice = DEFAULT_VOICES[language]
|
||||
_LOGGER.debug(
|
||||
"Unsupported voice %s detected, falling back to default %s for %s",
|
||||
voice,
|
||||
default_voice,
|
||||
language,
|
||||
)
|
||||
voice = default_voice
|
||||
# Process TTS
|
||||
try:
|
||||
data = await self.cloud.voice.process_tts(
|
||||
text=message,
|
||||
language=language,
|
||||
gender=gender,
|
||||
voice=voice,
|
||||
output=options[ATTR_AUDIO_OUTPUT],
|
||||
**_prepare_voice_args(
|
||||
hass=self.hass,
|
||||
language=language,
|
||||
voice=options.get(
|
||||
ATTR_VOICE,
|
||||
self._voice
|
||||
if language == self._language
|
||||
else DEFAULT_VOICES[language],
|
||||
),
|
||||
gender=options.get(ATTR_GENDER),
|
||||
),
|
||||
)
|
||||
except VoiceError as err:
|
||||
_LOGGER.error("Voice error: %s", err)
|
||||
|
@ -4,7 +4,6 @@ from collections.abc import Callable, Coroutine
|
||||
from copy import deepcopy
|
||||
import datetime
|
||||
from http import HTTPStatus
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
@ -20,7 +19,6 @@ from hass_nabucasa.auth import (
|
||||
)
|
||||
from hass_nabucasa.const import STATE_CONNECTED
|
||||
from hass_nabucasa.remote import CertificateStatus
|
||||
from hass_nabucasa.voice_data import TTS_VOICES
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
@ -31,6 +29,7 @@ from homeassistant.components.alexa import errors as alexa_errors
|
||||
from homeassistant.components.alexa.entities import LightCapabilities
|
||||
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
||||
from homeassistant.components.cloud.http_api import validate_language_voice
|
||||
from homeassistant.components.google_assistant.helpers import GoogleEntity
|
||||
from homeassistant.components.homeassistant import exposed_entities
|
||||
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
|
||||
@ -1822,17 +1821,14 @@ async def test_tts_info(
|
||||
response = await client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {
|
||||
"languages": json.loads(
|
||||
json.dumps(
|
||||
[
|
||||
(language, voice)
|
||||
for language, voices in TTS_VOICES.items()
|
||||
for voice in voices
|
||||
]
|
||||
)
|
||||
)
|
||||
}
|
||||
assert "languages" in response["result"]
|
||||
assert all(len(lang) for lang in response["result"]["languages"])
|
||||
assert len(response["result"]["languages"]) > 300
|
||||
assert (
|
||||
len([lang for lang in response["result"]["languages"] if "||" in lang[1]]) > 100
|
||||
)
|
||||
for lang in response["result"]["languages"]:
|
||||
assert validate_language_voice(lang[:2])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
x
Reference in New Issue
Block a user