mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Cloud: Add web socket API to pick default TTS language (#45064)
* Allow picking default TTS language * Fix test * Fix coroutine function * Improve test coverage * Remove stale import * Clean up hass Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
4e71be852a
commit
82746616fa
@ -20,6 +20,8 @@ PREF_GOOGLE_LOCAL_WEBHOOK_ID = "google_local_webhook_id"
|
|||||||
PREF_USERNAME = "username"
|
PREF_USERNAME = "username"
|
||||||
PREF_ALEXA_DEFAULT_EXPOSE = "alexa_default_expose"
|
PREF_ALEXA_DEFAULT_EXPOSE = "alexa_default_expose"
|
||||||
PREF_GOOGLE_DEFAULT_EXPOSE = "google_default_expose"
|
PREF_GOOGLE_DEFAULT_EXPOSE = "google_default_expose"
|
||||||
|
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
|
||||||
|
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "female")
|
||||||
DEFAULT_DISABLE_2FA = False
|
DEFAULT_DISABLE_2FA = False
|
||||||
DEFAULT_ALEXA_REPORT_STATE = False
|
DEFAULT_ALEXA_REPORT_STATE = False
|
||||||
DEFAULT_GOOGLE_REPORT_STATE = False
|
DEFAULT_GOOGLE_REPORT_STATE = False
|
||||||
|
@ -8,6 +8,7 @@ import async_timeout
|
|||||||
import attr
|
import attr
|
||||||
from hass_nabucasa import Cloud, auth, thingtalk
|
from hass_nabucasa import Cloud, auth, thingtalk
|
||||||
from hass_nabucasa.const import STATE_DISCONNECTED
|
from hass_nabucasa.const import STATE_DISCONNECTED
|
||||||
|
from hass_nabucasa.voice import MAP_VOICE
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
@ -37,6 +38,7 @@ from .const import (
|
|||||||
PREF_GOOGLE_DEFAULT_EXPOSE,
|
PREF_GOOGLE_DEFAULT_EXPOSE,
|
||||||
PREF_GOOGLE_REPORT_STATE,
|
PREF_GOOGLE_REPORT_STATE,
|
||||||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||||
|
PREF_TTS_DEFAULT_VOICE,
|
||||||
REQUEST_TIMEOUT,
|
REQUEST_TIMEOUT,
|
||||||
InvalidTrustedNetworks,
|
InvalidTrustedNetworks,
|
||||||
InvalidTrustedProxies,
|
InvalidTrustedProxies,
|
||||||
@ -115,6 +117,7 @@ async def async_setup(hass):
|
|||||||
async_register_command(alexa_sync)
|
async_register_command(alexa_sync)
|
||||||
|
|
||||||
async_register_command(thingtalk_convert)
|
async_register_command(thingtalk_convert)
|
||||||
|
async_register_command(tts_info)
|
||||||
|
|
||||||
hass.http.register_view(GoogleActionsSyncView)
|
hass.http.register_view(GoogleActionsSyncView)
|
||||||
hass.http.register_view(CloudLoginView)
|
hass.http.register_view(CloudLoginView)
|
||||||
@ -385,6 +388,9 @@ async def websocket_subscription(hass, connection, msg):
|
|||||||
vol.Optional(PREF_ALEXA_DEFAULT_EXPOSE): [str],
|
vol.Optional(PREF_ALEXA_DEFAULT_EXPOSE): [str],
|
||||||
vol.Optional(PREF_GOOGLE_DEFAULT_EXPOSE): [str],
|
vol.Optional(PREF_GOOGLE_DEFAULT_EXPOSE): [str],
|
||||||
vol.Optional(PREF_GOOGLE_SECURE_DEVICES_PIN): vol.Any(None, str),
|
vol.Optional(PREF_GOOGLE_SECURE_DEVICES_PIN): vol.Any(None, str),
|
||||||
|
vol.Optional(PREF_TTS_DEFAULT_VOICE): vol.All(
|
||||||
|
vol.Coerce(tuple), vol.In(MAP_VOICE)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async def websocket_update_prefs(hass, connection, msg):
|
async def websocket_update_prefs(hass, connection, msg):
|
||||||
@ -637,3 +643,11 @@ async def thingtalk_convert(hass, connection, msg):
|
|||||||
)
|
)
|
||||||
except thingtalk.ThingTalkConversionError as err:
|
except thingtalk.ThingTalkConversionError as err:
|
||||||
connection.send_error(msg["id"], ws_const.ERR_UNKNOWN_ERROR, str(err))
|
connection.send_error(msg["id"], ws_const.ERR_UNKNOWN_ERROR, str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command({"type": "cloud/tts/info"})
|
||||||
|
def tts_info(hass, connection, msg):
|
||||||
|
"""Fetch available tts info."""
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"], {"languages": [(lang, gender.value) for lang, gender in MAP_VOICE]}
|
||||||
|
)
|
||||||
|
@ -12,6 +12,7 @@ from .const import (
|
|||||||
DEFAULT_ALEXA_REPORT_STATE,
|
DEFAULT_ALEXA_REPORT_STATE,
|
||||||
DEFAULT_EXPOSED_DOMAINS,
|
DEFAULT_EXPOSED_DOMAINS,
|
||||||
DEFAULT_GOOGLE_REPORT_STATE,
|
DEFAULT_GOOGLE_REPORT_STATE,
|
||||||
|
DEFAULT_TTS_DEFAULT_VOICE,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
PREF_ALEXA_DEFAULT_EXPOSE,
|
PREF_ALEXA_DEFAULT_EXPOSE,
|
||||||
PREF_ALEXA_ENTITY_CONFIGS,
|
PREF_ALEXA_ENTITY_CONFIGS,
|
||||||
@ -30,6 +31,7 @@ from .const import (
|
|||||||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||||
PREF_OVERRIDE_NAME,
|
PREF_OVERRIDE_NAME,
|
||||||
PREF_SHOULD_EXPOSE,
|
PREF_SHOULD_EXPOSE,
|
||||||
|
PREF_TTS_DEFAULT_VOICE,
|
||||||
PREF_USERNAME,
|
PREF_USERNAME,
|
||||||
InvalidTrustedNetworks,
|
InvalidTrustedNetworks,
|
||||||
InvalidTrustedProxies,
|
InvalidTrustedProxies,
|
||||||
@ -86,6 +88,7 @@ class CloudPreferences:
|
|||||||
google_report_state=UNDEFINED,
|
google_report_state=UNDEFINED,
|
||||||
alexa_default_expose=UNDEFINED,
|
alexa_default_expose=UNDEFINED,
|
||||||
google_default_expose=UNDEFINED,
|
google_default_expose=UNDEFINED,
|
||||||
|
tts_default_voice=UNDEFINED,
|
||||||
):
|
):
|
||||||
"""Update user preferences."""
|
"""Update user preferences."""
|
||||||
prefs = {**self._prefs}
|
prefs = {**self._prefs}
|
||||||
@ -103,6 +106,7 @@ class CloudPreferences:
|
|||||||
(PREF_GOOGLE_REPORT_STATE, google_report_state),
|
(PREF_GOOGLE_REPORT_STATE, google_report_state),
|
||||||
(PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose),
|
(PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose),
|
||||||
(PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose),
|
(PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose),
|
||||||
|
(PREF_TTS_DEFAULT_VOICE, tts_default_voice),
|
||||||
):
|
):
|
||||||
if value is not UNDEFINED:
|
if value is not UNDEFINED:
|
||||||
prefs[key] = value
|
prefs[key] = value
|
||||||
@ -203,6 +207,7 @@ class CloudPreferences:
|
|||||||
PREF_GOOGLE_ENTITY_CONFIGS: self.google_entity_configs,
|
PREF_GOOGLE_ENTITY_CONFIGS: self.google_entity_configs,
|
||||||
PREF_GOOGLE_REPORT_STATE: self.google_report_state,
|
PREF_GOOGLE_REPORT_STATE: self.google_report_state,
|
||||||
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
|
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
|
||||||
|
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -279,6 +284,11 @@ class CloudPreferences:
|
|||||||
"""Return the published cloud webhooks."""
|
"""Return the published cloud webhooks."""
|
||||||
return self._prefs.get(PREF_CLOUDHOOKS, {})
|
return self._prefs.get(PREF_CLOUDHOOKS, {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tts_default_voice(self):
|
||||||
|
"""Return the default TTS voice."""
|
||||||
|
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE)
|
||||||
|
|
||||||
async def get_cloud_user(self) -> str:
|
async def get_cloud_user(self) -> str:
|
||||||
"""Return ID from Home Assistant Cloud system user."""
|
"""Return ID from Home Assistant Cloud system user."""
|
||||||
user = await self._load_cloud_user()
|
user = await self._load_cloud_user()
|
||||||
|
@ -12,13 +12,14 @@ CONF_GENDER = "gender"
|
|||||||
|
|
||||||
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
||||||
|
|
||||||
DEFAULT_LANG = "en-US"
|
|
||||||
DEFAULT_GENDER = "female"
|
|
||||||
|
|
||||||
|
|
||||||
def validate_lang(value):
|
def validate_lang(value):
|
||||||
"""Validate chosen gender or language."""
|
"""Validate chosen gender or language."""
|
||||||
lang = value[CONF_LANG]
|
lang = value.get(CONF_LANG)
|
||||||
|
|
||||||
|
if lang is None:
|
||||||
|
return value
|
||||||
|
|
||||||
gender = value.get(CONF_GENDER)
|
gender = value.get(CONF_GENDER)
|
||||||
|
|
||||||
if gender is None:
|
if gender is None:
|
||||||
@ -35,7 +36,7 @@ def validate_lang(value):
|
|||||||
PLATFORM_SCHEMA = vol.All(
|
PLATFORM_SCHEMA = vol.All(
|
||||||
PLATFORM_SCHEMA.extend(
|
PLATFORM_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): str,
|
vol.Optional(CONF_LANG): str,
|
||||||
vol.Optional(CONF_GENDER): str,
|
vol.Optional(CONF_GENDER): str,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
@ -48,8 +49,8 @@ async def async_get_engine(hass, config, discovery_info=None):
|
|||||||
cloud: Cloud = hass.data[DOMAIN]
|
cloud: Cloud = hass.data[DOMAIN]
|
||||||
|
|
||||||
if discovery_info is not None:
|
if discovery_info is not None:
|
||||||
language = DEFAULT_LANG
|
language = None
|
||||||
gender = DEFAULT_GENDER
|
gender = None
|
||||||
else:
|
else:
|
||||||
language = config[CONF_LANG]
|
language = config[CONF_LANG]
|
||||||
gender = config[CONF_GENDER]
|
gender = config[CONF_GENDER]
|
||||||
@ -67,6 +68,16 @@ class CloudProvider(Provider):
|
|||||||
self._language = language
|
self._language = language
|
||||||
self._gender = gender
|
self._gender = gender
|
||||||
|
|
||||||
|
if self._language is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._language, self._gender = cloud.client.prefs.tts_default_voice
|
||||||
|
cloud.client.prefs.async_listen_updates(self._sync_prefs)
|
||||||
|
|
||||||
|
async def _sync_prefs(self, prefs):
|
||||||
|
"""Sync preferences."""
|
||||||
|
self._language, self._gender = prefs.tts_default_voice
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_language(self):
|
def default_language(self):
|
||||||
"""Return the default language."""
|
"""Return the default language."""
|
||||||
|
@ -4,7 +4,7 @@ from ipaddress import ip_network
|
|||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from hass_nabucasa import thingtalk
|
from hass_nabucasa import thingtalk, voice
|
||||||
from hass_nabucasa.auth import Unauthenticated, UnknownError
|
from hass_nabucasa.auth import Unauthenticated, UnknownError
|
||||||
from hass_nabucasa.const import STATE_CONNECTED
|
from hass_nabucasa.const import STATE_CONNECTED
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
@ -361,6 +361,7 @@ async def test_websocket_status(
|
|||||||
"alexa_report_state": False,
|
"alexa_report_state": False,
|
||||||
"google_report_state": False,
|
"google_report_state": False,
|
||||||
"remote_enabled": False,
|
"remote_enabled": False,
|
||||||
|
"tts_default_voice": ["en-US", "female"],
|
||||||
},
|
},
|
||||||
"alexa_entities": {
|
"alexa_entities": {
|
||||||
"include_domains": [],
|
"include_domains": [],
|
||||||
@ -491,6 +492,7 @@ async def test_websocket_update_preferences(
|
|||||||
"google_secure_devices_pin": "1234",
|
"google_secure_devices_pin": "1234",
|
||||||
"google_default_expose": ["light", "switch"],
|
"google_default_expose": ["light", "switch"],
|
||||||
"alexa_default_expose": ["sensor", "media_player"],
|
"alexa_default_expose": ["sensor", "media_player"],
|
||||||
|
"tts_default_voice": ["en-GB", "male"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response = await client.receive_json()
|
response = await client.receive_json()
|
||||||
@ -501,6 +503,7 @@ async def test_websocket_update_preferences(
|
|||||||
assert setup_api.google_secure_devices_pin == "1234"
|
assert setup_api.google_secure_devices_pin == "1234"
|
||||||
assert setup_api.google_default_expose == ["light", "switch"]
|
assert setup_api.google_default_expose == ["light", "switch"]
|
||||||
assert setup_api.alexa_default_expose == ["sensor", "media_player"]
|
assert setup_api.alexa_default_expose == ["sensor", "media_player"]
|
||||||
|
assert setup_api.tts_default_voice == ("en-GB", "male")
|
||||||
|
|
||||||
|
|
||||||
async def test_websocket_update_preferences_require_relink(
|
async def test_websocket_update_preferences_require_relink(
|
||||||
@ -975,3 +978,25 @@ async def test_thingtalk_convert_internal(hass, hass_ws_client, setup_api):
|
|||||||
assert not response["success"]
|
assert not response["success"]
|
||||||
assert response["error"]["code"] == "unknown_error"
|
assert response["error"]["code"] == "unknown_error"
|
||||||
assert response["error"]["message"] == "Did not understand"
|
assert response["error"]["message"] == "Did not understand"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_info(hass, hass_ws_client, setup_api):
|
||||||
|
"""Test that we can get TTS info."""
|
||||||
|
# Verify the format is as expected
|
||||||
|
assert voice.MAP_VOICE[("en-US", voice.Gender.FEMALE)] == "JennyNeural"
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"homeassistant.components.cloud.http_api.MAP_VOICE",
|
||||||
|
{
|
||||||
|
("en-US", voice.Gender.MALE): "GuyNeural",
|
||||||
|
("en-US", voice.Gender.FEMALE): "JennyNeural",
|
||||||
|
},
|
||||||
|
clear=True,
|
||||||
|
):
|
||||||
|
await client.send_json({"id": 5, "type": "cloud/tts/info"})
|
||||||
|
response = await client.receive_json()
|
||||||
|
|
||||||
|
assert response["success"]
|
||||||
|
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
|
||||||
|
@ -1,5 +1,22 @@
|
|||||||
"""Tests for cloud tts."""
|
"""Tests for cloud tts."""
|
||||||
from homeassistant.components.cloud import tts
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from hass_nabucasa import voice
|
||||||
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.cloud import const, tts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def cloud_with_prefs(cloud_prefs):
|
||||||
|
"""Return a cloud mock with prefs."""
|
||||||
|
return Mock(client=Mock(prefs=cloud_prefs))
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_exists():
|
||||||
|
"""Test our default language exists."""
|
||||||
|
assert const.DEFAULT_TTS_DEFAULT_VOICE in voice.MAP_VOICE
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
@ -9,7 +26,61 @@ def test_schema():
|
|||||||
processed = tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL"})
|
processed = tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL"})
|
||||||
assert processed["gender"] == "female"
|
assert processed["gender"] == "female"
|
||||||
|
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
tts.PLATFORM_SCHEMA(
|
||||||
|
{"platform": "cloud", "language": "non-existing", "gender": "female"}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
tts.PLATFORM_SCHEMA(
|
||||||
|
{"platform": "cloud", "language": "nl-NL", "gender": "not-supported"}
|
||||||
|
)
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
processed = tts.PLATFORM_SCHEMA(
|
tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL", "gender": "female"})
|
||||||
{"platform": "cloud", "language": "nl-NL", "gender": "female"}
|
tts.PLATFORM_SCHEMA({"platform": "cloud"})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prefs_default_voice(hass, cloud_with_prefs, cloud_prefs):
|
||||||
|
"""Test cloud provider uses the preferences."""
|
||||||
|
assert cloud_prefs.tts_default_voice == ("en-US", "female")
|
||||||
|
|
||||||
|
provider_pref = await tts.async_get_engine(
|
||||||
|
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||||
)
|
)
|
||||||
|
provider_conf = await tts.async_get_engine(
|
||||||
|
Mock(data={const.DOMAIN: cloud_with_prefs}),
|
||||||
|
{"language": "fr-FR", "gender": "female"},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider_pref.default_language == "en-US"
|
||||||
|
assert provider_pref.default_options == {"gender": "female"}
|
||||||
|
assert provider_conf.default_language == "fr-FR"
|
||||||
|
assert provider_conf.default_options == {"gender": "female"}
|
||||||
|
|
||||||
|
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert provider_pref.default_language == "nl-NL"
|
||||||
|
assert provider_pref.default_options == {"gender": "male"}
|
||||||
|
assert provider_conf.default_language == "fr-FR"
|
||||||
|
assert provider_conf.default_options == {"gender": "female"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_provider_properties(cloud_with_prefs):
|
||||||
|
"""Test cloud provider."""
|
||||||
|
provider = await tts.async_get_engine(
|
||||||
|
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||||
|
)
|
||||||
|
assert provider.supported_options == ["gender"]
|
||||||
|
assert "nl-NL" in provider.supported_languages
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio(cloud_with_prefs):
|
||||||
|
"""Test cloud provider."""
|
||||||
|
provider = await tts.async_get_engine(
|
||||||
|
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||||
|
)
|
||||||
|
assert provider.supported_options == ["gender"]
|
||||||
|
assert "nl-NL" in provider.supported_languages
|
||||||
|
Loading…
x
Reference in New Issue
Block a user