diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 01d0a6a15e3..0f731a51485 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/tts/ """ import asyncio -import functools import hashlib import logging import mimetypes @@ -247,8 +246,6 @@ class SpeechManager(object): def async_register_engine(self, engine, provider, config): """Register a TTS provider.""" provider.hass = self.hass - if CONF_LANG in config: - provider.language = config.get(CONF_LANG) self.providers[engine] = provider @asyncio.coroutine @@ -257,9 +254,16 @@ class SpeechManager(object): This method is a coroutine. """ + provider = self.providers[engine] + + language = language or provider.default_language + if language is None or \ + language not in provider.supported_languages: + raise HomeAssistantError("Not supported language {0}".format( + language)) + msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() - language_key = language or self.providers[engine].language - key = KEY_PATTERN.format(msg_hash, language_key, engine).lower() + key = KEY_PATTERN.format(msg_hash, language, engine).lower() use_cache = cache if cache is not None else self.use_cache # is speech allready in memory @@ -387,13 +391,22 @@ class Provider(object): """Represent a single provider.""" hass = None - language = None - def get_tts_audio(self, message, language=None): + @property + def default_language(self): + """Default language.""" + return None + + @property + def supported_languages(self): + """List of supported languages.""" + return None + + def get_tts_audio(self, message, language): """Load tts audio file from provider.""" raise NotImplementedError() - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load tts audio file from provider. Return a tuple of file extension and data as bytes. @@ -401,8 +414,7 @@ class Provider(object): This method must be run in the event loop and returns a coroutine. """ return self.hass.loop.run_in_executor( - None, - functools.partial(self.get_tts_audio, message, language=language)) + None, self.get_tts_audio, message, language) class TextToSpeechView(HomeAssistantView): diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py index 68d49d58f78..88afa0643f2 100644 --- a/homeassistant/components/tts/demo.py +++ b/homeassistant/components/tts/demo.py @@ -6,28 +6,50 @@ https://home-assistant.io/components/demo/ """ import os -from homeassistant.components.tts import Provider +import voluptuous as vol + +from homeassistant.components.tts import Provider, PLATFORM_SCHEMA, CONF_LANG + +SUPPORT_LANGUAGES = [ + 'en', 'de' +] + +DEFAULT_LANG = 'en' + +PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ + vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES), +}) def get_engine(hass, config): """Setup Demo speech component.""" - return DemoProvider() + return DemoProvider(config[CONF_LANG]) class DemoProvider(Provider): """Demo speech api provider.""" - def __init__(self): - """Initialize demo provider for TTS.""" - self.language = 'en' + def __init__(self, lang): + """Initialize demo provider.""" + self._lang = lang - def get_tts_audio(self, message, language=None): + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + + def get_tts_audio(self, message, language): """Load TTS from demo.""" filename = os.path.join(os.path.dirname(__file__), "demo.mp3") try: with open(filename, 'rb') as voice: data = voice.read() except OSError: - return + return (None, None) return ("mp3", data) diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py index e1bb4e5e4e5..dc03013d4f1 100644 --- a/homeassistant/components/tts/google.py +++ b/homeassistant/components/tts/google.py @@ -42,15 +42,16 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ @asyncio.coroutine def async_get_engine(hass, config): """Setup Google speech component.""" - return GoogleProvider(hass) + return GoogleProvider(hass, config[CONF_LANG]) class GoogleProvider(Provider): """Google speech api provider.""" - def __init__(self, hass): + def __init__(self, hass, lang): """Init Google TTS service.""" self.hass = hass + self._lang = lang self.headers = { 'Referer': "http://translate.google.com/", 'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) " @@ -58,8 +59,18 @@ class GoogleProvider(Provider): "Chrome/47.0.2526.106 Safari/537.36") } + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + @asyncio.coroutine - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load TTS from google.""" from gtts_token import gtts_token @@ -67,11 +78,6 @@ class GoogleProvider(Provider): websession = async_get_clientsession(self.hass) message_parts = self._split_message_to_parts(message) - # If language is not specified or is not supported - use the language - # from the config. - if language not in SUPPORT_LANGUAGES: - language = self.language - data = b'' for idx, part in enumerate(message_parts): part_token = yield from self.hass.loop.run_in_executor( diff --git a/homeassistant/components/tts/picotts.py b/homeassistant/components/tts/picotts.py index 366973813a2..28db88c03b0 100644 --- a/homeassistant/components/tts/picotts.py +++ b/homeassistant/components/tts/picotts.py @@ -29,18 +29,31 @@ def get_engine(hass, config): if shutil.which("pico2wave") is None: _LOGGER.error("'pico2wave' was not found") return False - return PicoProvider() + return PicoProvider(config[CONF_LANG]) class PicoProvider(Provider): """pico speech api provider.""" - def get_tts_audio(self, message, language=None): + def __init__(self, lang): + """Initialize pico provider.""" + self._lang = lang + + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + + def get_tts_audio(self, message, language): """Load TTS using pico2wave.""" - if language not in SUPPORT_LANGUAGES: - language = self.language with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpf: fname = tmpf.name + cmd = ['pico2wave', '--wave', fname, '-l', language, message] subprocess.call(cmd) data = None @@ -52,6 +65,7 @@ class PicoProvider(Provider): return (None, None) finally: os.remove(fname) + if data: return ("wav", data) return (None, None) diff --git a/homeassistant/components/tts/voicerss.py b/homeassistant/components/tts/voicerss.py index 688ae7f6e25..2dda27b0c06 100644 --- a/homeassistant/components/tts/voicerss.py +++ b/homeassistant/components/tts/voicerss.py @@ -93,27 +93,34 @@ class VoiceRSSProvider(Provider): def __init__(self, hass, conf): """Init VoiceRSS TTS service.""" self.hass = hass - self.extension = conf.get(CONF_CODEC) + self._extension = conf[CONF_CODEC] + self._lang = conf[CONF_LANG] - self.form_data = { - 'key': conf.get(CONF_API_KEY), - 'hl': conf.get(CONF_LANG), - 'c': (conf.get(CONF_CODEC)).upper(), - 'f': conf.get(CONF_FORMAT), + self._form_data = { + 'key': conf[CONF_API_KEY], + 'hl': conf[CONF_LANG], + 'c': (conf[CONF_CODEC]).upper(), + 'f': conf[CONF_FORMAT], } + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + @asyncio.coroutine - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load TTS from voicerss.""" websession = async_get_clientsession(self.hass) - form_data = self.form_data.copy() + form_data = self._form_data.copy() form_data['src'] = message - - # If language is specified and supported - use it instead of the - # language in the config. - if language in SUPPORT_LANGUAGES: - form_data['hl'] = language + form_data['hl'] = language request = None try: @@ -141,4 +148,4 @@ class VoiceRSSProvider(Provider): if request is not None: yield from request.release() - return (self.extension, data) + return (self._extension, data) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 55381395313..715b98c4740 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -22,7 +22,7 @@ class TestTTS(object): def setup_method(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - self.demo_provider = DemoProvider() + self.demo_provider = DemoProvider('en') self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR) def teardown_method(self): @@ -95,7 +95,7 @@ class TestTTS(object): config = { tts.DOMAIN: { 'platform': 'demo', - 'language': 'lang' + 'language': 'de' } } @@ -111,11 +111,23 @@ class TestTTS(object): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_lang_demo.mp3") \ + "_de_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + + def test_setup_component_and_test_service_with_wrong_conf_language(self): + """Setup the demo platform and call service with wrong config.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'language': 'ru' + } + } + + with assert_setup_component(0, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) def test_setup_component_and_test_service_with_service_language(self): """Setup the demo platform and call service.""" @@ -127,6 +139,35 @@ class TestTTS(object): } } + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "de", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_de_demo.mp3") \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + + def test_setup_component_test_service_with_wrong_service_language(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + with assert_setup_component(1, tts.DOMAIN): setup_component(self.hass, tts.DOMAIN, config) @@ -136,13 +177,8 @@ class TestTTS(object): }) self.hass.block_till_done() - assert len(calls) == 1 - assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC - assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( - "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_lang_demo.mp3") \ - != -1 - assert os.path.isfile(os.path.join( + assert len(calls) == 0 + assert not os.path.isfile(os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) @@ -198,7 +234,7 @@ class TestTTS(object): assert len(calls) == 1 req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') assert req.status_code == 200 assert req.content == demo_data @@ -319,7 +355,7 @@ class TestTTS(object): """Setup demo platform with cache and call service without cache.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") @@ -339,7 +375,7 @@ class TestTTS(object): setup_component(self.hass, tts.DOMAIN, config) with patch('homeassistant.components.tts.demo.DemoProvider.' - 'get_tts_audio', return_value=None): + 'get_tts_audio', return_value=(None, None)): self.hass.services.call(tts.DOMAIN, 'demo_say', { tts.ATTR_MESSAGE: "I person is on front of your door.", }) @@ -352,7 +388,7 @@ class TestTTS(object): != -1 @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', - return_value=None) + return_value=(None, None)) def test_setup_component_test_with_error_on_get_tts(self, tts_mock): """Setup demo platform with wrong get_tts_audio.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -375,7 +411,7 @@ class TestTTS(object): def test_setup_component_load_cache_retrieve_without_mem_cache(self): """Setup component and load cache and get without mem cache.""" - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")