diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 5ee92747196..eda0fa27f53 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -5,6 +5,8 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/tts/ """ import asyncio +import ctypes +import functools as ft import hashlib import logging import mimetypes @@ -49,9 +51,11 @@ SERVICE_CLEAR_CACHE = 'clear_cache' ATTR_MESSAGE = 'message' ATTR_CACHE = 'cache' ATTR_LANGUAGE = 'language' +ATTR_OPTIONS = 'options' -_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([a-z]+)\.[a-z0-9]{3,4}") -KEY_PATTERN = '{}_{}_{}' +_RE_VOICE_FILE = re.compile( + r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z]+)\.[a-z0-9]{3,4}") +KEY_PATTERN = '{0}_{1}_{2}_{3}' PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({ vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, @@ -60,12 +64,12 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({ vol.All(vol.Coerce(int), vol.Range(min=60, max=57600)), }) - SCHEMA_SERVICE_SAY = vol.Schema({ vol.Required(ATTR_MESSAGE): cv.string, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_CACHE): cv.boolean, - vol.Optional(ATTR_LANGUAGE): cv.string + vol.Optional(ATTR_LANGUAGE): cv.string, + vol.Optional(ATTR_OPTIONS): dict, }) SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) @@ -125,10 +129,13 @@ def async_setup(hass, config): message = service.data.get(ATTR_MESSAGE) cache = service.data.get(ATTR_CACHE) language = service.data.get(ATTR_LANGUAGE) + options = service.data.get(ATTR_OPTIONS) try: url = yield from tts.async_get_url( - p_type, message, cache=cache, language=language) + p_type, message, cache=cache, language=language, + options=options + ) except HomeAssistantError as err: _LOGGER.error("Error on init tts: %s", err) return @@ -212,7 +219,9 @@ class SpeechManager(object): record = _RE_VOICE_FILE.match(file_data) if record: key = KEY_PATTERN.format( - record.group(1), record.group(2), record.group(3)) + record.group(1), record.group(2), record.group(3), + record.group(4) + ) cache[key.lower()] = file_data.lower() return cache @@ -249,22 +258,37 @@ class SpeechManager(object): self.providers[engine] = provider @asyncio.coroutine - def async_get_url(self, engine, message, cache=None, language=None): + def async_get_url(self, engine, message, cache=None, language=None, + options=None): """Get URL for play message. This method is a coroutine. """ provider = self.providers[engine] + msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() + use_cache = cache if cache is not None else self.use_cache + # languages language = language or provider.default_language if language is None or \ language not in provider.supported_languages: raise HomeAssistantError("Not supported language {0}".format( language)) - msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() - key = KEY_PATTERN.format(msg_hash, language, engine).lower() - use_cache = cache if cache is not None else self.use_cache + # options + options = options or provider.default_options + if options is not None: + invalid_opts = [opt_name for opt_name in options.keys() + if opt_name not in provider.supported_options] + if invalid_opts: + raise HomeAssistantError( + "Invalid options found: %s", invalid_opts) + options_key = ctypes.c_size_t(hash(frozenset(options))).value + else: + options_key = '-' + + key = KEY_PATTERN.format( + msg_hash, language, options_key, engine).lower() # is speech allready in memory if key in self.mem_cache: @@ -276,20 +300,21 @@ class SpeechManager(object): # load speech from provider into memory else: filename = yield from self.async_get_tts_audio( - engine, key, message, use_cache, language) + engine, key, message, use_cache, language, options) return "{}/api/tts_proxy/{}".format( self.hass.config.api.base_url, filename) @asyncio.coroutine - def async_get_tts_audio(self, engine, key, message, cache, language): + def async_get_tts_audio(self, engine, key, message, cache, language, + options): """Receive TTS and store for view in cache. This method is a coroutine. """ provider = self.providers[engine] extension, data = yield from provider.async_get_tts_audio( - message, language) + message, language, options) if data is None or extension is None: raise HomeAssistantError( @@ -377,7 +402,7 @@ class SpeechManager(object): raise HomeAssistantError("Wrong tts file format!") key = KEY_PATTERN.format( - record.group(1), record.group(2), record.group(3)) + record.group(1), record.group(2), record.group(3), record.group(4)) if key not in self.mem_cache: if key not in self.file_cache: @@ -403,11 +428,21 @@ class Provider(object): """List of supported languages.""" return None - def get_tts_audio(self, message, language): + @property + def supported_options(self): + """List of supported options like voice, emotionen.""" + return None + + @property + def default_options(self): + """Dict include default options.""" + return None + + def get_tts_audio(self, message, language, options=None): """Load tts audio file from provider.""" raise NotImplementedError() - def async_get_tts_audio(self, message, language): + def async_get_tts_audio(self, message, language, options=None): """Load tts audio file from provider. Return a tuple of file extension and data as bytes. @@ -415,7 +450,8 @@ class Provider(object): This method must be run in the event loop and returns a coroutine. """ return self.hass.loop.run_in_executor( - None, self.get_tts_audio, message, language) + None, ft.partial( + self.get_tts_audio, message, language, options=options)) class TextToSpeechView(HomeAssistantView): diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py index 88afa0643f2..95362b49db9 100644 --- a/homeassistant/components/tts/demo.py +++ b/homeassistant/components/tts/demo.py @@ -43,7 +43,12 @@ class DemoProvider(Provider): """List of supported languages.""" return SUPPORT_LANGUAGES - def get_tts_audio(self, message, language): + @property + def supported_options(self): + """List of supported options like voice, emotionen.""" + return ['voice', 'age'] + + def get_tts_audio(self, message, language, options=None): """Load TTS from demo.""" filename = os.path.join(os.path.dirname(__file__), "demo.mp3") try: diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py index 10ce3de6c6b..32c9663eedc 100644 --- a/homeassistant/components/tts/google.py +++ b/homeassistant/components/tts/google.py @@ -70,7 +70,7 @@ class GoogleProvider(Provider): return SUPPORT_LANGUAGES @asyncio.coroutine - def async_get_tts_audio(self, message, language): + def async_get_tts_audio(self, message, language, options=None): """Load TTS from google.""" from gtts_token import gtts_token diff --git a/homeassistant/components/tts/picotts.py b/homeassistant/components/tts/picotts.py index 3cc133864b6..49addd9b177 100644 --- a/homeassistant/components/tts/picotts.py +++ b/homeassistant/components/tts/picotts.py @@ -49,7 +49,7 @@ class PicoProvider(Provider): """List of supported languages.""" return SUPPORT_LANGUAGES - def get_tts_audio(self, message, language): + def get_tts_audio(self, message, language, options=None): """Load TTS using pico2wave.""" with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpf: fname = tmpf.name diff --git a/homeassistant/components/tts/voicerss.py b/homeassistant/components/tts/voicerss.py index f7a97a354f0..b0c74d1de30 100644 --- a/homeassistant/components/tts/voicerss.py +++ b/homeassistant/components/tts/voicerss.py @@ -114,7 +114,7 @@ class VoiceRSSProvider(Provider): return SUPPORT_LANGUAGES @asyncio.coroutine - def async_get_tts_audio(self, message, language): + def async_get_tts_audio(self, message, language, options=None): """Load TTS from VoiceRSS.""" websession = async_get_clientsession(self.hass) form_data = self._form_data.copy() diff --git a/homeassistant/components/tts/yandextts.py b/homeassistant/components/tts/yandextts.py index d5825ce297f..4dc8618f0d0 100644 --- a/homeassistant/components/tts/yandextts.py +++ b/homeassistant/components/tts/yandextts.py @@ -77,7 +77,7 @@ class YandexSpeechKitProvider(Provider): return SUPPORT_LANGUAGES @asyncio.coroutine - def async_get_tts_audio(self, message, language): + def async_get_tts_audio(self, message, language, options=None): """Load TTS from yandex.""" websession = async_get_clientsession(self.hass) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 715b98c4740..023c05edf99 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1,7 +1,8 @@ """The tests for the TTS component.""" +import ctypes import os import shutil -from unittest.mock import patch +from unittest.mock import patch, PropertyMock import requests @@ -82,11 +83,11 @@ class TestTTS(object): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_en_demo.mp3") \ + "_en_-_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3")) def test_setup_component_and_test_service_with_config_language(self): """Setup the demo platform and call service.""" @@ -111,11 +112,11 @@ class TestTTS(object): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_de_demo.mp3") \ + "_de_-_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_-_demo.mp3")) def test_setup_component_and_test_service_with_wrong_conf_language(self): """Setup the demo platform and call service with wrong config.""" @@ -152,11 +153,11 @@ class TestTTS(object): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_de_demo.mp3") \ + "_de_-_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_-_demo.mp3")) def test_setup_component_test_service_with_wrong_service_language(self): """Setup the demo platform and call service.""" @@ -180,7 +181,106 @@ class TestTTS(object): assert len(calls) == 0 assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_-_demo.mp3")) + + def test_setup_component_and_test_service_with_service_options(self): + """Setup the demo platform and call service with options.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de", + tts.ATTR_OPTIONS: { + 'voice': 'alex' + } + }) + self.hass.block_till_done() + + opt_hash = ctypes.c_size_t(hash(frozenset({'voice': 'alex'}))).value + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_de_{0}_demo.mp3".format(opt_hash)) \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_{0}_demo.mp3".format( + opt_hash))) + + @patch('homeassistant.components.tts.demo.DemoProvider.default_options', + new_callable=PropertyMock(return_value={'voice': 'alex'})) + def test_setup_component_and_test_with_service_options_def(self, def_mock): + """Setup the demo platform and call service with default options.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de", + }) + self.hass.block_till_done() + + opt_hash = ctypes.c_size_t(hash(frozenset({'voice': 'alex'}))).value + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_de_{0}_demo.mp3".format(opt_hash)) \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_{0}_demo.mp3".format( + opt_hash))) + + def test_setup_component_and_test_service_with_service_options_wrong(self): + """Setup the demo platform and call service with wrong options.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de", + tts.ATTR_OPTIONS: { + 'speed': 1 + } + }) + self.hass.block_till_done() + + opt_hash = ctypes.c_size_t(hash(frozenset({'speed': 1}))).value + + assert len(calls) == 0 + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_{0}_demo.mp3".format( + opt_hash))) def test_setup_component_and_test_service_clear_cache(self): """Setup the demo platform and call service clear cache.""" @@ -203,14 +303,14 @@ class TestTTS(object): assert len(calls) == 1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3")) self.hass.services.call(tts.DOMAIN, tts.SERVICE_CLEAR_CACHE, {}) self.hass.block_till_done() assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3")) def test_setup_component_and_test_service_with_receive_voice(self): """Setup the demo platform and call service and receive voice.""" @@ -278,7 +378,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_en_demo.mp3").format(self.hass.config.api.base_url) + "_en_-_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 404 @@ -297,7 +397,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd" - "_en_demo.mp3").format(self.hass.config.api.base_url) + "_en_-_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 404 @@ -324,7 +424,7 @@ class TestTTS(object): assert len(calls) == 1 assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3")) def test_setup_component_test_with_cache_call_service_without_cache(self): """Setup demo platform with cache and call service without cache.""" @@ -349,7 +449,7 @@ class TestTTS(object): assert len(calls) == 1 assert not os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3")) def test_setup_component_test_with_cache_dir(self): """Setup demo platform with cache and call service without cache.""" @@ -358,7 +458,7 @@ class TestTTS(object): _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3") os.mkdir(self.default_tts_cache) with open(cache_file, "wb") as voice_file: @@ -384,7 +484,7 @@ class TestTTS(object): assert len(calls) == 1 assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_en_demo.mp3") \ + "_en_-_demo.mp3") \ != -1 @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', @@ -414,7 +514,7 @@ class TestTTS(object): _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") + "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3") os.mkdir(self.default_tts_cache) with open(cache_file, "wb") as voice_file: @@ -433,7 +533,7 @@ class TestTTS(object): self.hass.start() url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_en_demo.mp3").format(self.hass.config.api.base_url) + "_en_-_demo.mp3").format(self.hass.config.api.base_url) req = requests.get(url) assert req.status_code == 200