From f4b1a8e42d4f142151b252cd1ac214793fe01ba0 Mon Sep 17 00:00:00 2001 From: Tod Schmidt Date: Tue, 17 Apr 2018 09:24:54 -0400 Subject: [PATCH] Added web view for TTS to get url (#13882) * Added web view for to get url * Added web view for TTS to get url * Added web view for TTS to get url * Added web view for TTS to get url * Fixed test * added auth * Update __init__.py --- homeassistant/components/tts/__init__.py | 115 ++++++++++++++--------- tests/components/tts/test_init.py | 46 ++++++++- 2 files changed, 117 insertions(+), 44 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 17aa66ea825..999b584360c 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -37,6 +37,7 @@ ATTR_CACHE = 'cache' ATTR_LANGUAGE = 'language' ATTR_MESSAGE = 'message' ATTR_OPTIONS = 'options' +ATTR_PLATFORM = 'platform' CONF_CACHE = 'cache' CONF_CACHE_DIR = 'cache_dir' @@ -77,8 +78,7 @@ SCHEMA_SERVICE_SAY = vol.Schema({ SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) -@asyncio.coroutine -def async_setup(hass, config): +async def async_setup(hass, config): """Set up TTS.""" tts = SpeechManager(hass) @@ -88,27 +88,27 @@ def async_setup(hass, config): cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY) - yield from tts.async_init_cache(use_cache, cache_dir, time_memory) + await tts.async_init_cache(use_cache, cache_dir, time_memory) except (HomeAssistantError, KeyError) as err: _LOGGER.error("Error on cache init %s", err) return False hass.http.register_view(TextToSpeechView(tts)) + hass.http.register_view(TextToSpeechUrlView(tts)) - @asyncio.coroutine - def async_setup_platform(p_type, p_config, disc_info=None): + async def async_setup_platform(p_type, p_config, disc_info=None): """Set up a TTS platform.""" - platform = yield from async_prepare_setup_platform( + platform = await async_prepare_setup_platform( hass, config, DOMAIN, p_type) if platform is None: return try: if hasattr(platform, 'async_get_engine'): - provider = yield from platform.async_get_engine( + provider = await platform.async_get_engine( hass, p_config) else: - provider = yield from hass.async_add_job( + provider = await hass.async_add_job( platform.get_engine, hass, p_config) if provider is None: @@ -120,8 +120,7 @@ def async_setup(hass, config): _LOGGER.exception("Error setting up platform %s", p_type) return - @asyncio.coroutine - def async_say_handle(service): + async def async_say_handle(service): """Service handle for say.""" entity_ids = service.data.get(ATTR_ENTITY_ID) message = service.data.get(ATTR_MESSAGE) @@ -130,7 +129,7 @@ def async_setup(hass, config): options = service.data.get(ATTR_OPTIONS) try: - url = yield from tts.async_get_url( + url = await tts.async_get_url( p_type, message, cache=cache, language=language, options=options ) @@ -146,7 +145,7 @@ def async_setup(hass, config): if entity_ids: data[ATTR_ENTITY_ID] = entity_ids - yield from hass.services.async_call( + await hass.services.async_call( DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True) hass.services.async_register( @@ -157,12 +156,11 @@ def async_setup(hass, config): in config_per_platform(config, DOMAIN)] if setup_tasks: - yield from asyncio.wait(setup_tasks, loop=hass.loop) + await asyncio.wait(setup_tasks, loop=hass.loop) - @asyncio.coroutine - def async_clear_cache_handle(service): + async def async_clear_cache_handle(service): """Handle clear cache service call.""" - yield from tts.async_clear_cache() + await tts.async_clear_cache() hass.services.async_register( DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle, @@ -185,8 +183,7 @@ class SpeechManager(object): self.file_cache = {} self.mem_cache = {} - @asyncio.coroutine - def async_init_cache(self, use_cache, cache_dir, time_memory): + async def async_init_cache(self, use_cache, cache_dir, time_memory): """Init config folder and load file cache.""" self.use_cache = use_cache self.time_memory = time_memory @@ -201,7 +198,7 @@ class SpeechManager(object): return cache_dir try: - self.cache_dir = yield from self.hass.async_add_job( + self.cache_dir = await self.hass.async_add_job( init_tts_cache_dir, cache_dir) except OSError as err: raise HomeAssistantError("Can't init cache dir {}".format(err)) @@ -222,15 +219,14 @@ class SpeechManager(object): return cache try: - cache_files = yield from self.hass.async_add_job(get_cache_files) + cache_files = await self.hass.async_add_job(get_cache_files) except OSError as err: raise HomeAssistantError("Can't read cache dir {}".format(err)) if cache_files: self.file_cache.update(cache_files) - @asyncio.coroutine - def async_clear_cache(self): + async def async_clear_cache(self): """Read file cache and delete files.""" self.mem_cache = {} @@ -243,7 +239,7 @@ class SpeechManager(object): _LOGGER.warning( "Can't remove cache file '%s': %s", filename, err) - yield from self.hass.async_add_job(remove_files) + await self.hass.async_add_job(remove_files) self.file_cache = {} @callback @@ -254,9 +250,8 @@ class SpeechManager(object): provider.name = engine self.providers[engine] = provider - @asyncio.coroutine - def async_get_url(self, engine, message, cache=None, language=None, - options=None): + async def async_get_url(self, engine, message, cache=None, language=None, + options=None): """Get URL for play message. This method is a coroutine. @@ -301,21 +296,20 @@ class SpeechManager(object): self.hass.async_add_job(self.async_file_to_mem(key)) # Load speech from provider into memory else: - filename = yield from self.async_get_tts_audio( + filename = await self.async_get_tts_audio( engine, key, message, use_cache, language, options) return "{}/api/tts_proxy/{}".format( self.hass.config.api.base_url, filename) - @asyncio.coroutine - def async_get_tts_audio(self, engine, key, message, cache, language, - options): + async def async_get_tts_audio(self, engine, key, message, cache, language, + options): """Receive TTS and store for view in cache. This method is a coroutine. """ provider = self.providers[engine] - extension, data = yield from provider.async_get_tts_audio( + extension, data = await provider.async_get_tts_audio( message, language, options) if data is None or extension is None: @@ -337,8 +331,7 @@ class SpeechManager(object): return filename - @asyncio.coroutine - def async_save_tts_audio(self, key, filename, data): + async def async_save_tts_audio(self, key, filename, data): """Store voice data to file and file_cache. This method is a coroutine. @@ -351,13 +344,12 @@ class SpeechManager(object): speech.write(data) try: - yield from self.hass.async_add_job(save_speech) + await self.hass.async_add_job(save_speech) self.file_cache[key] = filename except OSError: _LOGGER.error("Can't write %s", filename) - @asyncio.coroutine - def async_file_to_mem(self, key): + async def async_file_to_mem(self, key): """Load voice from file cache into memory. This method is a coroutine. @@ -374,7 +366,7 @@ class SpeechManager(object): return speech.read() try: - data = yield from self.hass.async_add_job(load_speech) + data = await self.hass.async_add_job(load_speech) except OSError: del self.file_cache[key] raise HomeAssistantError("Can't read {}".format(voice_file)) @@ -396,8 +388,7 @@ class SpeechManager(object): self.hass.loop.call_later(self.time_memory, async_remove_from_mem) - @asyncio.coroutine - def async_read_tts(self, filename): + async def async_read_tts(self, filename): """Read a voice file and return binary. This method is a coroutine. @@ -412,7 +403,7 @@ class SpeechManager(object): if key not in self.mem_cache: if key not in self.file_cache: raise HomeAssistantError("{} not in cache!".format(key)) - yield from self.async_file_to_mem(key) + await self.async_file_to_mem(key) content, _ = mimetypes.guess_type(filename) return (content, self.mem_cache[key][MEM_CACHE_VOICE]) @@ -490,6 +481,45 @@ class Provider(object): ft.partial(self.get_tts_audio, message, language, options=options)) +class TextToSpeechUrlView(HomeAssistantView): + """TTS view to get a url to a generated speech file.""" + + requires_auth = True + url = '/api/tts_get_url' + name = 'api:tts:geturl' + + def __init__(self, tts): + """Initialize a tts view.""" + self.tts = tts + + async def post(self, request): + """Generate speech and provide url.""" + try: + data = await request.json() + except ValueError: + return self.json_message('Invalid JSON specified', 400) + if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE): + return self.json_message('Must specify platform and message', 400) + + p_type = data[ATTR_PLATFORM] + message = data[ATTR_MESSAGE] + cache = data.get(ATTR_CACHE) + language = data.get(ATTR_LANGUAGE) + options = data.get(ATTR_OPTIONS) + + try: + url = await self.tts.async_get_url( + p_type, message, cache=cache, language=language, + options=options + ) + resp = self.json({'url': url}, 200) + except HomeAssistantError as err: + _LOGGER.error("Error on init tts: %s", err) + resp = self.json({'error': err}, 400) + + return resp + + class TextToSpeechView(HomeAssistantView): """TTS view to serve a speech audio.""" @@ -501,11 +531,10 @@ class TextToSpeechView(HomeAssistantView): """Initialize a tts view.""" self.tts = tts - @asyncio.coroutine - def get(self, request, filename): + async def get(self, request, filename): """Start a get request.""" try: - content, data = yield from self.tts.async_read_tts(filename) + content, data = await self.tts.async_read_tts(filename) except HomeAssistantError as err: _LOGGER.error("Error on load tts: %s", err) return web.Response(status=404) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 7a15ed28f97..b6bfa430fd2 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -2,6 +2,7 @@ import ctypes import os import shutil +import json from unittest.mock import patch, PropertyMock import pytest @@ -353,7 +354,7 @@ class TestTTS(object): demo_data = tts.SpeechManager.write_tags( "265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3", demo_data, self.demo_provider, - "I person is on front of your door.", 'en', None) + "AI person is in front of your door.", 'en', None) assert req.status_code == 200 assert req.content == demo_data @@ -562,3 +563,46 @@ class TestTTS(object): req = requests.get(url) assert req.status_code == 200 assert req.content == demo_data + + def test_setup_component_and_web_get_url(self): + """Setup the demo platform and receive wrong file from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url) + data = {'platform': 'demo', + 'message': "I person is on front of your door."} + + req = requests.post(url, data=json.dumps(data)) + assert req.status_code == 200 + response = json.loads(req.text) + assert response.get('url') == (("{}/api/tts_proxy/265944c108cbb00b2a62" + "1be5930513e03a0bb2cd_en_-_demo.mp3") + .format(self.hass.config.api.base_url)) + + def test_setup_component_and_web_get_url_bad_config(self): + """Setup the demo platform and receive wrong file from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url) + data = {'message': "I person is on front of your door."} + + req = requests.post(url, data=data) + assert req.status_code == 400