mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add voice styles to HA Cloud (#143605)
* Add voice styles to HA Cloud * Add separator and extract util
This commit is contained in:
parent
a584ccb8f7
commit
fdcb88977a
@ -93,3 +93,5 @@ STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
|||||||
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||||
|
|
||||||
LOGIN_MFA_TIMEOUT = 60
|
LOGIN_MFA_TIMEOUT = 60
|
||||||
|
|
||||||
|
VOICE_STYLE_SEPERATOR = "||"
|
||||||
|
@ -57,6 +57,7 @@ from .const import (
|
|||||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||||
PREF_TTS_DEFAULT_VOICE,
|
PREF_TTS_DEFAULT_VOICE,
|
||||||
REQUEST_TIMEOUT,
|
REQUEST_TIMEOUT,
|
||||||
|
VOICE_STYLE_SEPERATOR,
|
||||||
)
|
)
|
||||||
from .google_config import CLOUD_GOOGLE
|
from .google_config import CLOUD_GOOGLE
|
||||||
from .repairs import async_manage_legacy_subscription_issue
|
from .repairs import async_manage_legacy_subscription_issue
|
||||||
@ -591,10 +592,21 @@ async def websocket_subscription(
|
|||||||
def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
    """Check that a (language, voice) pair names a known cloud TTS voice.

    The voice entry may carry an optional style suffix, written as
    ``<voice><VOICE_STYLE_SEPERATOR><style>``. The language must be a key of
    TTS_VOICES, the voice must be listed for that language, and a non-empty
    style must appear in the voice's "variants" list.

    Returns the input tuple unchanged when it is valid.
    Raises vol.Invalid for an unknown language, voice or style.
    """
    language, voice_with_style = value
    voice, _, style = voice_with_style.partition(VOICE_STYLE_SEPERATOR)

    if language not in TTS_VOICES:
        raise vol.Invalid(f"Invalid language {language}")

    voices_for_language = TTS_VOICES[language]
    if voice not in voices_for_language:
        raise vol.Invalid(f"Invalid voice {voice} for language {language}")

    # No style requested (or an empty one): nothing further to verify.
    if not style:
        return value

    # A voice described by a plain string has no variants at all.
    voice_info = voices_for_language[voice]
    style_is_known = not isinstance(voice_info, str) and style in voice_info.get(
        "variants", []
    )
    if not style_is_known:
        raise vol.Invalid(
            f"Invalid style {style} for voice {voice} in language {language}"
        )
    return value
|
||||||
|
|
||||||
|
|
||||||
@ -1012,13 +1024,24 @@ def tts_info(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fetch available tts info."""
|
"""Fetch available tts info."""
|
||||||
connection.send_result(
|
result = []
|
||||||
msg["id"],
|
for language, voices in TTS_VOICES.items():
|
||||||
{
|
for voice_id, voice_info in voices.items():
|
||||||
"languages": [
|
if isinstance(voice_info, str):
|
||||||
(language, voice)
|
result.append((language, voice_id, voice_info))
|
||||||
for language, voices in TTS_VOICES.items()
|
continue
|
||||||
for voice in voices
|
|
||||||
]
|
name = voice_info["name"]
|
||||||
},
|
result.append((language, voice_id, name))
|
||||||
|
result.extend(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
language,
|
||||||
|
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||||
|
f"{name} ({variant})",
|
||||||
)
|
)
|
||||||
|
for variant in voice_info.get("variants", [])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
connection.send_result(msg["id"], {"languages": result})
|
||||||
|
@ -31,7 +31,13 @@ from homeassistant.setup import async_when_setup
|
|||||||
|
|
||||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import DATA_CLOUD, DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
from .const import (
|
||||||
|
DATA_CLOUD,
|
||||||
|
DATA_PLATFORMS_SETUP,
|
||||||
|
DOMAIN,
|
||||||
|
TTS_ENTITY_UNIQUE_ID,
|
||||||
|
VOICE_STYLE_SEPERATOR,
|
||||||
|
)
|
||||||
from .prefs import CloudPreferences
|
from .prefs import CloudPreferences
|
||||||
|
|
||||||
ATTR_GENDER = "gender"
|
ATTR_GENDER = "gender"
|
||||||
@ -195,6 +201,39 @@ DEFAULT_VOICES = {
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
def _prepare_voice_args(
    *,
    hass: HomeAssistant,
    language: str,
    voice: str,
    gender: str | None,
) -> dict:
    """Prepare the voice-related keyword arguments for a cloud TTS request.

    ``voice`` may carry an optional style suffix, written as
    ``<voice><VOICE_STYLE_SEPERATOR><style>``. The voice part is passed
    through the deprecated-voice handling; if the result is not listed in
    TTS_VOICES for ``language``, the language's entry in DEFAULT_VOICES is
    used instead and the requested style is dropped.

    Returns a dict with the keys "language", "voice", "gender" and "style"
    ("style" is None when no style applies).
    """
    gender = handle_deprecated_gender(hass, gender)

    # Split off the optional style; an absent or empty style becomes None.
    original_voice, _, style_part = voice.partition(VOICE_STYLE_SEPERATOR)
    style: str | None = style_part or None

    updated_voice = handle_deprecated_voice(hass, original_voice)
    if updated_voice not in TTS_VOICES[language]:
        default_voice = DEFAULT_VOICES[language]
        _LOGGER.debug(
            "Unsupported voice %s detected, falling back to default %s for %s",
            voice,
            default_voice,
            language,
        )
        updated_voice = default_voice
        # The style was requested for the rejected voice; pairing it with the
        # default voice could yield a voice/style combination that does not
        # exist, so it is not carried over.
        style = None

    return {
        "language": language,
        "voice": updated_voice,
        "gender": gender,
        "style": style,
    }
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_platform(value: str) -> str:
|
def _deprecated_platform(value: str) -> str:
|
||||||
"""Validate if platform is deprecated."""
|
"""Validate if platform is deprecated."""
|
||||||
if value == DOMAIN:
|
if value == DOMAIN:
|
||||||
@ -332,42 +371,59 @@ class CloudTTSEntity(TextToSpeechEntity):
|
|||||||
"""Return a list of supported voices for a language."""
|
"""Return a list of supported voices for a language."""
|
||||||
if not (voices := TTS_VOICES.get(language)):
|
if not (voices := TTS_VOICES.get(language)):
|
||||||
return None
|
return None
|
||||||
return [
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for voice_id, voice_info in voices.items():
|
||||||
|
if isinstance(voice_info, str):
|
||||||
|
result.append(
|
||||||
Voice(
|
Voice(
|
||||||
voice,
|
voice_id,
|
||||||
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
|
voice_info,
|
||||||
)
|
)
|
||||||
for voice, voice_info in voices.items()
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = voice_info["name"]
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
Voice(
|
||||||
|
voice_id,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result.extend(
|
||||||
|
[
|
||||||
|
Voice(
|
||||||
|
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||||
|
f"{name} ({variant})",
|
||||||
|
)
|
||||||
|
for variant in voice_info.get("variants", [])
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any]
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load TTS from Home Assistant Cloud."""
|
"""Load TTS from Home Assistant Cloud."""
|
||||||
gender: Gender | str | None = options.get(ATTR_GENDER)
|
|
||||||
gender = handle_deprecated_gender(self.hass, gender)
|
|
||||||
original_voice: str = options.get(
|
|
||||||
ATTR_VOICE,
|
|
||||||
self._voice if language == self._language else DEFAULT_VOICES[language],
|
|
||||||
)
|
|
||||||
voice = handle_deprecated_voice(self.hass, original_voice)
|
|
||||||
if voice not in TTS_VOICES[language]:
|
|
||||||
default_voice = DEFAULT_VOICES[language]
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Unsupported voice %s detected, falling back to default %s for %s",
|
|
||||||
voice,
|
|
||||||
default_voice,
|
|
||||||
language,
|
|
||||||
)
|
|
||||||
voice = default_voice
|
|
||||||
# Process TTS
|
# Process TTS
|
||||||
try:
|
try:
|
||||||
data = await self.cloud.voice.process_tts(
|
data = await self.cloud.voice.process_tts(
|
||||||
text=message,
|
text=message,
|
||||||
language=language,
|
|
||||||
gender=gender,
|
|
||||||
voice=voice,
|
|
||||||
output=options[ATTR_AUDIO_OUTPUT],
|
output=options[ATTR_AUDIO_OUTPUT],
|
||||||
|
**_prepare_voice_args(
|
||||||
|
hass=self.hass,
|
||||||
|
language=language,
|
||||||
|
voice=options.get(
|
||||||
|
ATTR_VOICE,
|
||||||
|
self._voice
|
||||||
|
if language == self._language
|
||||||
|
else DEFAULT_VOICES[language],
|
||||||
|
),
|
||||||
|
gender=options.get(ATTR_GENDER),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except VoiceError as err:
|
except VoiceError as err:
|
||||||
_LOGGER.error("Voice error: %s", err)
|
_LOGGER.error("Voice error: %s", err)
|
||||||
@ -411,13 +467,38 @@ class CloudProvider(Provider):
|
|||||||
"""Return a list of supported voices for a language."""
|
"""Return a list of supported voices for a language."""
|
||||||
if not (voices := TTS_VOICES.get(language)):
|
if not (voices := TTS_VOICES.get(language)):
|
||||||
return None
|
return None
|
||||||
return [
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for voice_id, voice_info in voices.items():
|
||||||
|
if isinstance(voice_info, str):
|
||||||
|
result.append(
|
||||||
Voice(
|
Voice(
|
||||||
voice,
|
voice_id,
|
||||||
voice_info["name"] if isinstance(voice_info, dict) else voice_info,
|
voice_info,
|
||||||
)
|
)
|
||||||
for voice, voice_info in voices.items()
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = voice_info["name"]
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
Voice(
|
||||||
|
voice_id,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result.extend(
|
||||||
|
[
|
||||||
|
Voice(
|
||||||
|
f"{voice_id}{VOICE_STYLE_SEPERATOR}{variant}",
|
||||||
|
f"{name} ({variant})",
|
||||||
|
)
|
||||||
|
for variant in voice_info.get("variants", [])
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_options(self) -> dict[str, str]:
|
def default_options(self) -> dict[str, str]:
|
||||||
@ -431,30 +512,22 @@ class CloudProvider(Provider):
|
|||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load TTS from Home Assistant Cloud."""
|
"""Load TTS from Home Assistant Cloud."""
|
||||||
assert self.hass is not None
|
assert self.hass is not None
|
||||||
gender: Gender | str | None = options.get(ATTR_GENDER)
|
|
||||||
gender = handle_deprecated_gender(self.hass, gender)
|
|
||||||
original_voice: str = options.get(
|
|
||||||
ATTR_VOICE,
|
|
||||||
self._voice if language == self._language else DEFAULT_VOICES[language],
|
|
||||||
)
|
|
||||||
voice = handle_deprecated_voice(self.hass, original_voice)
|
|
||||||
if voice not in TTS_VOICES[language]:
|
|
||||||
default_voice = DEFAULT_VOICES[language]
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Unsupported voice %s detected, falling back to default %s for %s",
|
|
||||||
voice,
|
|
||||||
default_voice,
|
|
||||||
language,
|
|
||||||
)
|
|
||||||
voice = default_voice
|
|
||||||
# Process TTS
|
# Process TTS
|
||||||
try:
|
try:
|
||||||
data = await self.cloud.voice.process_tts(
|
data = await self.cloud.voice.process_tts(
|
||||||
text=message,
|
text=message,
|
||||||
language=language,
|
|
||||||
gender=gender,
|
|
||||||
voice=voice,
|
|
||||||
output=options[ATTR_AUDIO_OUTPUT],
|
output=options[ATTR_AUDIO_OUTPUT],
|
||||||
|
**_prepare_voice_args(
|
||||||
|
hass=self.hass,
|
||||||
|
language=language,
|
||||||
|
voice=options.get(
|
||||||
|
ATTR_VOICE,
|
||||||
|
self._voice
|
||||||
|
if language == self._language
|
||||||
|
else DEFAULT_VOICES[language],
|
||||||
|
),
|
||||||
|
gender=options.get(ATTR_GENDER),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except VoiceError as err:
|
except VoiceError as err:
|
||||||
_LOGGER.error("Voice error: %s", err)
|
_LOGGER.error("Voice error: %s", err)
|
||||||
|
@ -4,7 +4,6 @@ from collections.abc import Callable, Coroutine
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import datetime
|
import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||||
@ -20,7 +19,6 @@ from hass_nabucasa.auth import (
|
|||||||
)
|
)
|
||||||
from hass_nabucasa.const import STATE_CONNECTED
|
from hass_nabucasa.const import STATE_CONNECTED
|
||||||
from hass_nabucasa.remote import CertificateStatus
|
from hass_nabucasa.remote import CertificateStatus
|
||||||
from hass_nabucasa.voice_data import TTS_VOICES
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
@ -31,6 +29,7 @@ from homeassistant.components.alexa import errors as alexa_errors
|
|||||||
from homeassistant.components.alexa.entities import LightCapabilities
|
from homeassistant.components.alexa.entities import LightCapabilities
|
||||||
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||||
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
||||||
|
from homeassistant.components.cloud.http_api import validate_language_voice
|
||||||
from homeassistant.components.google_assistant.helpers import GoogleEntity
|
from homeassistant.components.google_assistant.helpers import GoogleEntity
|
||||||
from homeassistant.components.homeassistant import exposed_entities
|
from homeassistant.components.homeassistant import exposed_entities
|
||||||
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
|
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
|
||||||
@ -1822,17 +1821,14 @@ async def test_tts_info(
|
|||||||
response = await client.receive_json()
|
response = await client.receive_json()
|
||||||
|
|
||||||
assert response["success"]
|
assert response["success"]
|
||||||
assert response["result"] == {
|
assert "languages" in response["result"]
|
||||||
"languages": json.loads(
|
assert all(len(lang) for lang in response["result"]["languages"])
|
||||||
json.dumps(
|
assert len(response["result"]["languages"]) > 300
|
||||||
[
|
assert (
|
||||||
(language, voice)
|
len([lang for lang in response["result"]["languages"] if "||" in lang[1]]) > 100
|
||||||
for language, voices in TTS_VOICES.items()
|
|
||||||
for voice in voices
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
)
|
for lang in response["result"]["languages"]:
|
||||||
}
|
assert validate_language_voice(lang[:2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user