From 4728fa8da64b215ac6fd5b8b3ef05893309f032a Mon Sep 17 00:00:00 2001 From: andrey-git Date: Tue, 27 Dec 2016 18:01:22 +0200 Subject: [PATCH] Allow to specify TTS language in the service call. (#5047) * Allow to specify TTS language in the service call. * Allow to specify TTS language in the service call. * Respect 79 char limit * Fix "Too many blank lines" * Fix "Too many blank lines" * Fix "Too many blank lines" * Change language to be optional parameter of *get_tts_audio * Change language to be optional parameter of *get_tts_audio * Respect 79 char limit * Don't pass "None" * Use default of "None" for TTS language * Use default of "None" for TTS language * Don't pass "None" * Change TTS cache key to be hash_lang_engine * Change language from demo to en * Fix wrong replace --- homeassistant/components/tts/__init__.py | 46 +++++---- homeassistant/components/tts/demo.py | 6 +- homeassistant/components/tts/google.py | 9 +- homeassistant/components/tts/services.yaml | 4 + homeassistant/components/tts/voicerss.py | 7 +- tests/components/tts/test_google.py | 32 +++++- tests/components/tts/test_init.py | 108 ++++++++++++++++++--- tests/components/tts/test_voicerss.py | 32 +++++- 8 files changed, 206 insertions(+), 38 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index bd19de52a98..01d0a6a15e3 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -5,8 +5,9 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/tts/ """ import asyncio -import logging +import functools import hashlib +import logging import mimetypes import os import re @@ -48,8 +49,10 @@ SERVICE_CLEAR_CACHE = 'clear_cache' ATTR_MESSAGE = 'message' ATTR_CACHE = 'cache' +ATTR_LANGUAGE = 'language' -_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([a-z]+)\.[a-z0-9]{3,4}") +_RE_VOICE_FILE = 
re.compile(r"([a-f0-9]{40})_([^_]+)_([a-z]+)\.[a-z0-9]{3,4}") +KEY_PATTERN = '{}_{}_{}' PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({ vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, @@ -63,6 +66,7 @@ SCHEMA_SERVICE_SAY = vol.Schema({ vol.Required(ATTR_MESSAGE): cv.string, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_CACHE): cv.boolean, + vol.Optional(ATTR_LANGUAGE): cv.string }) SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) @@ -121,10 +125,11 @@ def async_setup(hass, config): entity_ids = service.data.get(ATTR_ENTITY_ID) message = service.data.get(ATTR_MESSAGE) cache = service.data.get(ATTR_CACHE) + language = service.data.get(ATTR_LANGUAGE) try: url = yield from tts.async_get_url( - p_type, message, cache=cache) + p_type, message, cache=cache, language=language) except HomeAssistantError as err: _LOGGER.error("Error on init tts: %s", err) return @@ -207,7 +212,8 @@ class SpeechManager(object): for file_data in folder_data: record = _RE_VOICE_FILE.match(file_data) if record: - key = "{}_{}".format(record.group(1), record.group(2)) + key = KEY_PATTERN.format( + record.group(1), record.group(2), record.group(3)) cache[key.lower()] = file_data.lower() return cache @@ -241,17 +247,19 @@ class SpeechManager(object): def async_register_engine(self, engine, provider, config): """Register a TTS provider.""" provider.hass = self.hass - provider.language = config.get(CONF_LANG) + if CONF_LANG in config: + provider.language = config.get(CONF_LANG) self.providers[engine] = provider @asyncio.coroutine - def async_get_url(self, engine, message, cache=None): + def async_get_url(self, engine, message, cache=None, language=None): """Get URL for play message. This method is a coroutine. 
""" msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() - key = ("{}_{}".format(msg_hash, engine)).lower() + language_key = language or self.providers[engine].language + key = KEY_PATTERN.format(msg_hash, language_key, engine).lower() use_cache = cache if cache is not None else self.use_cache # is speech allready in memory @@ -260,23 +268,24 @@ class SpeechManager(object): # is file store in file cache elif use_cache and key in self.file_cache: filename = self.file_cache[key] - self.hass.async_add_job(self.async_file_to_mem(engine, key)) + self.hass.async_add_job(self.async_file_to_mem(key)) # load speech from provider into memory else: filename = yield from self.async_get_tts_audio( - engine, key, message, use_cache) + engine, key, message, use_cache, language) return "{}/api/tts_proxy/{}".format( self.hass.config.api.base_url, filename) @asyncio.coroutine - def async_get_tts_audio(self, engine, key, message, cache): + def async_get_tts_audio(self, engine, key, message, cache, language): """Receive TTS and store for view in cache. This method is a coroutine. """ provider = self.providers[engine] - extension, data = yield from provider.async_get_tts_audio(message) + extension, data = yield from provider.async_get_tts_audio( + message, language) if data is None or extension is None: raise HomeAssistantError( @@ -314,7 +323,7 @@ class SpeechManager(object): _LOGGER.error("Can't write %s", filename) @asyncio.coroutine - def async_file_to_mem(self, engine, key): + def async_file_to_mem(self, key): """Load voice from file cache into memory. This method is a coroutine. 
@@ -362,13 +371,13 @@ class SpeechManager(object): if not record: raise HomeAssistantError("Wrong tts file format!") - key = "{}_{}".format(record.group(1), record.group(2)) + key = KEY_PATTERN.format( + record.group(1), record.group(2), record.group(3)) if key not in self.mem_cache: if key not in self.file_cache: raise HomeAssistantError("%s not in cache!", key) - engine = record.group(2) - yield from self.async_file_to_mem(engine, key) + yield from self.async_file_to_mem(key) content, _ = mimetypes.guess_type(filename) return (content, self.mem_cache[key][MEM_CACHE_VOICE]) @@ -380,11 +389,11 @@ class Provider(object): hass = None language = None - def get_tts_audio(self, message): + def get_tts_audio(self, message, language=None): """Load tts audio file from provider.""" raise NotImplementedError() - def async_get_tts_audio(self, message): + def async_get_tts_audio(self, message, language=None): """Load tts audio file from provider. Return a tuple of file extension and data as bytes. @@ -392,7 +401,8 @@ class Provider(object): This method must be run in the event loop and returns a coroutine. 
""" return self.hass.loop.run_in_executor( - None, self.get_tts_audio, message) + None, + functools.partial(self.get_tts_audio, message, language=language)) class TextToSpeechView(HomeAssistantView): diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py index a63bd6373ea..68d49d58f78 100644 --- a/homeassistant/components/tts/demo.py +++ b/homeassistant/components/tts/demo.py @@ -17,7 +17,11 @@ def get_engine(hass, config): class DemoProvider(Provider): """Demo speech api provider.""" - def get_tts_audio(self, message): + def __init__(self): + """Initialize demo provider for TTS.""" + self.language = 'en' + + def get_tts_audio(self, message, language=None): """Load TTS from demo.""" filename = os.path.join(os.path.dirname(__file__), "demo.mp3") try: diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py index 49d53961062..e1bb4e5e4e5 100644 --- a/homeassistant/components/tts/google.py +++ b/homeassistant/components/tts/google.py @@ -59,7 +59,7 @@ class GoogleProvider(Provider): } @asyncio.coroutine - def async_get_tts_audio(self, message): + def async_get_tts_audio(self, message, language=None): """Load TTS from google.""" from gtts_token import gtts_token @@ -67,6 +67,11 @@ class GoogleProvider(Provider): websession = async_get_clientsession(self.hass) message_parts = self._split_message_to_parts(message) + # If language is not specified or is not supported - use the language + # from the config. 
+ if language not in SUPPORT_LANGUAGES: + language = self.language + data = b'' for idx, part in enumerate(message_parts): part_token = yield from self.hass.loop.run_in_executor( @@ -74,7 +79,7 @@ class GoogleProvider(Provider): url_param = { 'ie': 'UTF-8', - 'tl': self.language, + 'tl': language, 'q': yarl.quote(part), 'tk': part_token, 'total': len(message_parts), diff --git a/homeassistant/components/tts/services.yaml b/homeassistant/components/tts/services.yaml index 5cb146950b4..b44ef6ac66c 100644 --- a/homeassistant/components/tts/services.yaml +++ b/homeassistant/components/tts/services.yaml @@ -14,5 +14,9 @@ say: description: Control file cache of this message. example: 'true' + language: + description: Language to use for speech generation. + example: 'ru' + clear_cache: description: Remove cache files and RAM cache. diff --git a/homeassistant/components/tts/voicerss.py b/homeassistant/components/tts/voicerss.py index fdbe8a8d806..688ae7f6e25 100644 --- a/homeassistant/components/tts/voicerss.py +++ b/homeassistant/components/tts/voicerss.py @@ -103,13 +103,18 @@ class VoiceRSSProvider(Provider): } @asyncio.coroutine - def async_get_tts_audio(self, message): + def async_get_tts_audio(self, message, language=None): """Load TTS from voicerss.""" websession = async_get_clientsession(self.hass) form_data = self.form_data.copy() form_data['src'] = message + # If language is specified and supported - use it instead of the + # language in the config. 
+ if language in SUPPORT_LANGUAGES: + form_data['hl'] = language + request = None try: with async_timeout.timeout(10, loop=self.hass.loop): diff --git a/tests/components/tts/test_google.py b/tests/components/tts/test_google.py index 623a96f1dfb..3483a4830fa 100644 --- a/tests/components/tts/test_google.py +++ b/tests/components/tts/test_google.py @@ -80,8 +80,8 @@ class TestTTSGooglePlatform(object): @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, return_value=5) - def test_service_say_german(self, mock_calculate, aioclient_mock): - """Test service call say with german code.""" + def test_service_say_german_config(self, mock_calculate, aioclient_mock): + """Test service call say with german code in the config.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) self.url_param['tl'] = 'de' @@ -106,6 +106,34 @@ class TestTTSGooglePlatform(object): assert len(calls) == 1 assert len(aioclient_mock.mock_calls) == 1 + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_german_service(self, mock_calculate, aioclient_mock): + """Test service call say with german code in the service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.url_param['tl'] = 'de' + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de" + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, return_value=5) def test_service_say_error(self, mock_calculate, aioclient_mock): diff --git a/tests/components/tts/test_init.py 
b/tests/components/tts/test_init.py index fccd9d66bd7..55381395313 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -82,11 +82,69 @@ class TestTTS(object): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_demo.mp3") \ + "_en_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + + def test_setup_component_and_test_service_with_config_language(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'language': 'lang' + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_lang_demo.mp3") \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) + + def test_setup_component_and_test_service_with_service_language(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "lang", + }) + self.hass.block_till_done() + + 
assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_lang_demo.mp3") \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) def test_setup_component_and_test_service_clear_cache(self): """Setup the demo platform and call service clear cache.""" @@ -109,14 +167,14 @@ class TestTTS(object): assert len(calls) == 1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) self.hass.services.call(tts.DOMAIN, tts.SERVICE_CLEAR_CACHE, {}) self.hass.block_till_done() assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) def test_setup_component_and_test_service_with_receive_voice(self): """Setup the demo platform and call service and receive voice.""" @@ -144,6 +202,32 @@ class TestTTS(object): assert req.status_code == 200 assert req.content == demo_data + def test_setup_component_and_test_service_with_receive_voice_german(self): + """Setup the demo platform and call service and receive voice.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) + _, demo_data = self.demo_provider.get_tts_audio("bla", "de") + assert req.status_code == 200 + assert req.content == demo_data + def 
test_setup_component_and_web_view_wrong_file(self): """Setup the demo platform and receive wrong file from web.""" config = { @@ -158,7 +242,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_demo.mp3").format(self.hass.config.api.base_url) + "_en_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 404 @@ -177,7 +261,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd" - "_demo.mp3").format(self.hass.config.api.base_url) + "_en_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 404 @@ -204,7 +288,7 @@ class TestTTS(object): assert len(calls) == 1 assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) def test_setup_component_test_with_cache_call_service_without_cache(self): """Setup demo platform with cache and call service without cache.""" @@ -229,7 +313,7 @@ class TestTTS(object): assert len(calls) == 1 assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) def test_setup_component_test_with_cache_dir(self): """Setup demo platform with cache and call service without cache.""" @@ -238,7 +322,7 @@ class TestTTS(object): _, demo_data = self.demo_provider.get_tts_audio("bla") cache_file = os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") os.mkdir(self.default_tts_cache) with open(cache_file, "wb") as voice_file: @@ -264,7 +348,7 @@ class TestTTS(object): assert len(calls) == 1 assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_demo.mp3") \ + 
"_en_demo.mp3") \ != -1 @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', @@ -294,7 +378,7 @@ class TestTTS(object): _, demo_data = self.demo_provider.get_tts_audio("bla") cache_file = os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") os.mkdir(self.default_tts_cache) with open(cache_file, "wb") as voice_file: @@ -313,7 +397,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_demo.mp3").format(self.hass.config.api.base_url) + "_en_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 200 diff --git a/tests/components/tts/test_voicerss.py b/tests/components/tts/test_voicerss.py index ea1263b189e..b8f73487831 100644 --- a/tests/components/tts/test_voicerss.py +++ b/tests/components/tts/test_voicerss.py @@ -86,8 +86,8 @@ class TestTTSVoiceRSSPlatform(object): assert aioclient_mock.mock_calls[0][2] == self.form_data assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1 - def test_service_say_german(self, aioclient_mock): - """Test service call say with german code.""" + def test_service_say_german_config(self, aioclient_mock): + """Test service call say with german code in the config.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) self.form_data['hl'] = 'de-de' @@ -114,6 +114,34 @@ class TestTTSVoiceRSSPlatform(object): assert len(aioclient_mock.mock_calls) == 1 assert aioclient_mock.mock_calls[0][2] == self.form_data + def test_service_say_german_service(self, aioclient_mock): + """Test service call say with german code in the service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.form_data['hl'] = 'de-de' + aioclient_mock.post( + self.url, data=self.form_data, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'voicerss', + 'api_key': '1234567xx', + } + } + + 
with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'voicerss_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de-de" + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + assert aioclient_mock.mock_calls[0][2] == self.form_data + def test_service_say_error(self, aioclient_mock): """Test service call say with http response 400.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)