From 2dec38d8d4f2461448fc05de352e619bc09167dd Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Tue, 13 Dec 2016 08:23:08 +0100 Subject: [PATCH] TTS Component / Google speech platform (#4837) * TTS Component / Google speech platform * Change file backend handling / cache * Use mimetype / rename Provider function / allow cache on service call * Add a memcache for faster response * Add demo platform * First version of unittest * Address comments * improve error handling / address comments * Add google unittest & check http response code * Change url param handling * add test for other language * Change hash to sha256 for same hash on every os/hardware * add unittest for receive demo data * add test for error cases * Test case load from file to mem over aiohttp server * Use cache SpeechManager level, address other comments * Add service for clear cache * Update service.yaml * add support for spliting google message --- .../components/media_player/__init__.py | 2 +- homeassistant/components/tts/__init__.py | 421 ++++++++++++++++++ homeassistant/components/tts/demo.mp3 | Bin 0 -> 8256 bytes homeassistant/components/tts/demo.py | 29 ++ homeassistant/components/tts/google.py | 117 +++++ homeassistant/components/tts/services.yaml | 14 + requirements_all.txt | 3 + tests/components/tts/__init__.py | 1 + tests/components/tts/test_google.py | 199 +++++++++ tests/components/tts/test_init.py | 320 +++++++++++++ tests/test_util/aiohttp.py | 9 +- 11 files changed, 1110 insertions(+), 5 deletions(-) create mode 100644 homeassistant/components/tts/__init__.py create mode 100644 homeassistant/components/tts/demo.mp3 create mode 100644 homeassistant/components/tts/demo.py create mode 100644 homeassistant/components/tts/google.py create mode 100644 homeassistant/components/tts/services.yaml create mode 100644 tests/components/tts/__init__.py create mode 100644 tests/components/tts/test_google.py create mode 100644 tests/components/tts/test_init.py diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index fa2ecee4337..3dea75df874 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -162,7 +162,7 @@ MEDIA_PLAYER_MEDIA_SEEK_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ MEDIA_PLAYER_PLAY_MEDIA_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ vol.Required(ATTR_MEDIA_CONTENT_TYPE): cv.string, vol.Required(ATTR_MEDIA_CONTENT_ID): cv.string, - ATTR_MEDIA_ENQUEUE: cv.boolean, + vol.Optional(ATTR_MEDIA_ENQUEUE): cv.boolean, }) MEDIA_PLAYER_SELECT_SOURCE_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py new file mode 100644 index 00000000000..0e75de88cc5 --- /dev/null +++ b/homeassistant/components/tts/__init__.py @@ -0,0 +1,421 @@ +""" +Provide functionality to TTS. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/tts/ +""" +import asyncio +import logging +import hashlib +import mimetypes +import os +import re + +from aiohttp import web +import voluptuous as vol + +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.bootstrap import async_prepare_setup_platform +from homeassistant.core import callback +from homeassistant.config import load_yaml_config_file +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP) +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_per_platform +import homeassistant.helpers.config_validation as cv + +DOMAIN = 'tts' +DEPENDENCIES = ['http'] + +_LOGGER = logging.getLogger(__name__) + +MEM_CACHE_FILENAME = 'filename' +MEM_CACHE_VOICE = 'voice' + +CONF_LANG = 'language' +CONF_CACHE = 'cache' +CONF_CACHE_DIR = 'cache_dir' +CONF_TIME_MEMORY = 'time_memory' + +DEFAULT_CACHE = True +DEFAULT_CACHE_DIR = "tts" +DEFAULT_LANG = 'en' +DEFAULT_TIME_MEMORY = 300 + +SERVICE_SAY = 'say' +SERVICE_CLEAR_CACHE = 'clear_cache' + +ATTR_MESSAGE = 'message' +ATTR_CACHE = 'cache' + +_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([a-z]+)\.[a-z0-9]{3,4}") + +PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({ + vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.string, + vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, + vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string, + vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): + vol.All(vol.Coerce(int), vol.Range(min=60, max=57600)), +}) + + +SCHEMA_SERVICE_SAY = vol.Schema({ + vol.Required(ATTR_MESSAGE): cv.string, + vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, + vol.Optional(ATTR_CACHE): cv.boolean, +}) + +SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) + + +@asyncio.coroutine +def async_setup(hass, config): + """Setup TTS.""" + tts = SpeechManager(hass) + + try: + conf = config[DOMAIN][0] if len(config.get(DOMAIN, [])) > 0 else {} + use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE) + cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) + time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_LANG) + + yield from tts.async_init_cache(use_cache, cache_dir, time_memory) + except (HomeAssistantError, KeyError) as err: + _LOGGER.error("Error on cache init %s", err) + return False + + hass.http.register_view(TextToSpeechView(tts)) + + descriptions = yield from hass.loop.run_in_executor( + None, load_yaml_config_file, + os.path.join(os.path.dirname(__file__), 'services.yaml')) + + @asyncio.coroutine + def async_setup_platform(p_type, p_config, disc_info=None): + """Setup a tts platform.""" + platform = yield from async_prepare_setup_platform( + hass, config, DOMAIN, p_type) + if platform is None: + return + + try: + if hasattr(platform, 'async_get_engine'): + provider = yield from platform.async_get_engine( + hass, p_config) + else: + provider = yield from hass.loop.run_in_executor( + None, platform.get_engine, hass, p_config) + + if provider is None: + _LOGGER.error('Error setting up platform %s', p_type) + return + + tts.async_register_engine(p_type, provider, p_config) + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error setting up platform %s', p_type) + return + + @asyncio.coroutine + def async_say_handle(service): + """Service handle for say.""" + entity_ids = service.data.get(ATTR_ENTITY_ID) + message = service.data.get(ATTR_MESSAGE) + cache = service.data.get(ATTR_CACHE) + + try: + url = yield from tts.async_get_url( + p_type, message, cache=cache) + except HomeAssistantError as err: + _LOGGER.error("Error on init tts: %s", err) + return + + data = { + ATTR_MEDIA_CONTENT_ID: url, + ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_MUSIC, + } + + if entity_ids: + data[ATTR_ENTITY_ID] = entity_ids + + yield from hass.services.async_call( + DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True) + + hass.services.async_register( + DOMAIN, "{}_{}".format(p_type, SERVICE_SAY), async_say_handle, + descriptions.get(SERVICE_SAY), schema=SCHEMA_SERVICE_SAY) + + setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config + in config_per_platform(config, DOMAIN)] + + if setup_tasks: + yield from asyncio.wait(setup_tasks, loop=hass.loop) + + @asyncio.coroutine + def async_clear_cache_handle(service): + """Handle clear cache service call.""" + yield from tts.async_clear_cache() + + hass.services.async_register( + DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle, + descriptions.get(SERVICE_CLEAR_CACHE), schema=SERVICE_CLEAR_CACHE) + + return True + + +class SpeechManager(object): + """Representation of a speech store.""" + + def __init__(self, hass): + """Initialize a speech store.""" + self.hass = hass + self.providers = {} + + self.use_cache = True + self.cache_dir = None + self.time_memory = None + self.file_cache = {} + self.mem_cache = {} + + @asyncio.coroutine + def async_init_cache(self, use_cache, cache_dir, time_memory): + """Init config folder and load file cache.""" + self.use_cache = use_cache + self.time_memory = time_memory + + def init_tts_cache_dir(cache_dir): + """Init cache folder.""" + if not os.path.isabs(cache_dir): + cache_dir = self.hass.config.path(cache_dir) + if not os.path.isdir(cache_dir): + _LOGGER.info("Create cache dir %s.", cache_dir) + os.mkdir(cache_dir) + return cache_dir + + try: + self.cache_dir = yield from self.hass.loop.run_in_executor( + None, init_tts_cache_dir, cache_dir) + except OSError as err: + raise HomeAssistantError( + "Can't init cache dir {}".format(err)) + + def get_cache_files(): + """Return a dict of given engine files.""" + cache = {} + + folder_data = os.listdir(self.cache_dir) + for file_data in folder_data: + record = _RE_VOICE_FILE.match(file_data) + if record: + key = "{}_{}".format(record.group(1), record.group(2)) + cache[key.lower()] = file_data.lower() + return cache + + try: + cache_files = yield from self.hass.loop.run_in_executor( + None, get_cache_files) + except OSError as err: + raise HomeAssistantError( + "Can't read cache dir {}".format(err)) + + if cache_files: + self.file_cache.update(cache_files) + + @asyncio.coroutine + def async_clear_cache(self): + """Read file cache and delete files.""" + self.mem_cache = {} + + def remove_files(): + """Remove files from filesystem.""" + for _, filename in self.file_cache.items(): + try: + os.remove(os.path.join(self.cache_dir), filename) + except OSError: + pass + + yield from self.hass.loop.run_in_executor(None, remove_files) + self.file_cache = {} + + @callback + def async_register_engine(self, engine, provider, config): + """Register a TTS provider.""" + provider.hass = self.hass + provider.language = config.get(CONF_LANG) + self.providers[engine] = provider + + @asyncio.coroutine + def async_get_url(self, engine, message, cache=None): + """Get URL for play message. + + This method is a coroutine. + """ + msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() + key = ("{}_{}".format(msg_hash, engine)).lower() + use_cache = cache if cache is not None else self.use_cache + + # is speech allready in memory + if key in self.mem_cache: + filename = self.mem_cache[key][MEM_CACHE_FILENAME] + # is file store in file cache + elif use_cache and key in self.file_cache: + filename = self.file_cache[key] + self.hass.async_add_job(self.async_file_to_mem(engine, key)) + # load speech from provider into memory + else: + filename = yield from self.async_get_tts_audio( + engine, key, message, use_cache) + + return "{}/api/tts_proxy/{}".format( + self.hass.config.api.base_url, filename) + + @asyncio.coroutine + def async_get_tts_audio(self, engine, key, message, cache): + """Receive TTS and store for view in cache. + + This method is a coroutine. + """ + provider = self.providers[engine] + extension, data = yield from provider.async_get_tts_audio(message) + + if data is None or extension is None: + raise HomeAssistantError( + "No TTS from {} for '{}'".format(engine, message)) + + # create file infos + filename = ("{}.{}".format(key, extension)).lower() + + # save to memory + self._async_store_to_memcache(key, filename, data) + + if cache: + self.hass.async_add_job( + self.async_save_tts_audio(key, filename, data)) + + return filename + + @asyncio.coroutine + def async_save_tts_audio(self, key, filename, data): + """Store voice data to file and file_cache. + + This method is a coroutine. + """ + voice_file = os.path.join(self.cache_dir, filename) + + def save_speech(): + """Store speech to filesystem.""" + with open(voice_file, 'wb') as speech: + speech.write(data) + + try: + yield from self.hass.loop.run_in_executor(None, save_speech) + self.file_cache[key] = filename + except OSError: + _LOGGER.error("Can't write %s", filename) + + @asyncio.coroutine + def async_file_to_mem(self, engine, key): + """Load voice from file cache into memory. + + This method is a coroutine. + """ + filename = self.file_cache.get(key) + if not filename: + raise HomeAssistantError("Key {} not in file cache!".format(key)) + + voice_file = os.path.join(self.cache_dir, filename) + + def load_speech(): + """Load a speech from filesystem.""" + with open(voice_file, 'rb') as speech: + return speech.read() + + try: + data = yield from self.hass.loop.run_in_executor(None, load_speech) + except OSError: + raise HomeAssistantError("Can't read {}".format(voice_file)) + + self._async_store_to_memcache(key, filename, data) + + @callback + def _async_store_to_memcache(self, key, filename, data): + """Store data to memcache and set timer to remove it.""" + self.mem_cache[key] = { + MEM_CACHE_FILENAME: filename, + MEM_CACHE_VOICE: data, + } + + @callback + def async_remove_from_mem(): + """Cleanup memcache.""" + self.mem_cache.pop(key) + + self.hass.loop.call_later(self.time_memory, async_remove_from_mem) + + @asyncio.coroutine + def async_read_tts(self, filename): + """Read a voice file and return binary. + + This method is a coroutine. + """ + record = _RE_VOICE_FILE.match(filename.lower()) + if not record: + raise HomeAssistantError("Wrong tts file format!") + + key = "{}_{}".format(record.group(1), record.group(2)) + + if key not in self.mem_cache: + if key not in self.file_cache: + raise HomeAssistantError("%s not in cache!", key) + engine = record.group(2) + yield from self.async_file_to_mem(engine, key) + + content, _ = mimetypes.guess_type(filename) + return (content, self.mem_cache[key][MEM_CACHE_VOICE]) + + +class Provider(object): + """Represent a single provider.""" + + hass = None + language = DEFAULT_LANG + + def get_tts_audio(self, message): + """Load tts audio file from provider.""" + raise NotImplementedError() + + @asyncio.coroutine + def async_get_tts_audio(self, message): + """Load tts audio file from provider. + + Return a tuple of file extension and data as bytes. + + This method is a coroutine. + """ + extension, data = yield from self.hass.loop.run_in_executor( + None, self.get_tts_audio, message) + return (extension, data) + + +class TextToSpeechView(HomeAssistantView): + """TTS view to serve an speech audio.""" + + requires_auth = False + url = "/api/tts_proxy/{filename}" + name = "api:tts:speech" + + def __init__(self, tts): + """Initialize a tts view.""" + self.tts = tts + + @asyncio.coroutine + def get(self, request, filename): + """Start a get request.""" + try: + content, data = yield from self.tts.async_read_tts(filename) + except HomeAssistantError as err: + _LOGGER.error("Error on load tts: %s", err) + return web.Response(status=404) + + return web.Response(body=data, content_type=content) diff --git a/homeassistant/components/tts/demo.mp3 b/homeassistant/components/tts/demo.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..f34241c769856718331a5e523f7d9b8b026a2c30 GIT binary patch literal 8256 zcmeI#Wl)rFzX$MpX^_}mI;4@1E@_sM?hZi^kdj6PLAsaTg(Zb0B$bjfP*}QCM5UAt zkr2r}?EjfLXU@!d_ssL|TyJ*vn%U3(?(da1_D~%IKpOtA1N|5jCB;+v-?#OjXgd>D zGK3HjliX|m8R0UGj`7(HJI9hPG+HJaUKQxo@=oba%}il0v5RNzWJFQp@r4BZ3ux>G z5j(GPC8$J`z?6uP9ElqEG-8Xr_Q`4dR-A#a7!*^9c1|P(J7c3`Au)Yn9V}^1?Do6c zoDR!BKhG?V>8lx5&o7F55E$m1eOOW-x#U=+E<8`q>crw&O=;zmi0jXzI<3bxQ}MuE zuEhKZ76+V9LEEtA3;ZBMwaD0;lBB>3h&&IfRmQcQW+VJ>-tql=;94Uh>C7$Jn-z&M z-sUCzFvP2zZ%-_($)f{EAPYOOrze-#i58j2@dTDPSnOqJ+@svcB_u1JxB&sqC#CM_ z2nAL!?R-J2V9{bxg8nDsD_>l5|L73W?yF+cF0}Rt8in4?QC|*Frd4DrW!=53J8%9q za~6Kt6`WsldYbr{F^enwx_B)1cYy0EwgLhoZZmE^PZW0kr+ykg|1x0G2vHuZ%fFkN z@urq=-l;yNbB&!`exKondvY#Aos&i^2xjHv1Uy}Dh%Nd-NTm)7rH4yI^BmE z(*ctZ$ta(}o|`Q8YSxQSoAeCEDKFAY;CxD&&7w9Sln!zKZX}9s=uu|eSsGHBOV89BO4zAcroWiL=GvYxf8cKpTY>&EdY<*?`T_Rz zRcgKQntcLeRES)1RQ{EJ2D%MEFd>ZTU`V13=(nKkGp=qcOh%YmiC|Fh&@&iZUf`UUP_qkHJ?pPknueD91_w~#nH)F!sE z{Ga@3jM49aWGh4pEVZmLpM>G*&k1=$;(Xn72OvgBaP2QJ+W^Xr1GR!nt;e&?KB9vhAt7j=Wa60E}cRJ5e%k3jbvV&A5Ucf_h>@xn$* z8FL%+d9?NF(rrEPR?Sw39dk5|N5UoJ(J!iMp9M7U{3()^f0L?> zE@nO7$<{G?-OQa$|5b3NPP^#MrhXl3^7Dn`=d*BbZo`r5aYY}~s1RTZt?6#qv$@;K zUhol@BwfXy;(R(tRSO0{Wg-l*HSq`2DGV1ehy*(@xBBF89*whgFS^$!MMQbq|JiC| z(h2gm_2aWxFEe9?g;`q?k}9^GRrrDmKPe~&5Ly`ay>(YA6mk(XGb?$JUdCM7qN=z__rwY zl1UAmot&&)&tUqDe*-ffR@(IxRcmz=noCKhiLOZn&L&34#jfLgVeluWHek@kSK|rm zV4&{kw$t=8qU(u(O4<~0t(L+{8 zyxEzq8+U%4xrr0@Zq*Ua<9uVbL(Wi8o*U?_X4d=Z2|@B^uM1ldGI3x1m}_$+tVTKE z8`DNrAZQ8B=Vs1gss-h`fl@;@1?(hwzJmbP(svvwyHzd+r-$lX-2I4~r7V&5HO0}q z6Cw?M+!W#G3?l83u#FQ9GSV4x*EXsy>Of#5h)NUiv4|GI(e{bEve)%*PT zl;t1Rq^7DQTum^LkR#Hv+qj z>%L1>xiH|3%ui_=JXGw#`Q%Ve%6b5XLBZ`vR0eLhk{EpI-he-9Yfi6e8a9b-OKAS{ zWiVN?<@lq;9{j7|yZPOkQn@H%lF?LU#;{i^a)Dvyiq?djQK8+73wFYdG|qA(Rw$Jg zXNU@E0M4g_3_@!Gl3tXkZ(77FGbUR3;f12~Rz&X*|0_+`Lx<_HZ5Jh7xjNUMIa`owj3z5k$2(=`LVR&P2PBH} z+K)qBeCwPKz7V(7`r++RC^ZyXclpgvHDazq>~V|(D4FAgKW4FYgQr8AF|(+#^kjD6 zYSf=Hu6Ql;jvr8#Z$i;2^57O!&DJ+QK*^-!I6vshtMdfs3q!?45J#j^{SFS-l?LwV*`)(41PTsF2A=R zql-~h8!OVB%Ry-cPlE9$cv4X!-ZC0crkSDN3!dzHtxtkG7*ZoS9M1y#alRGA7>xp^ zM-fkHm*8c2DXQP_BsUu(njm4W;rj2zPnnCO4@i`v0GLI8*E|UszS5G+1x!YUTnIz; zSDk51)};F`9&o!^qH;=4T7fK3Y*Wk&fzJSzHF?_Rh4Wn?pYpDK9fUT)3lh<%V#474 z#%5+466q&vkAAYrv>Mad9EhrlL4TePv$defZ#1*xsch+~`dfxZuZqxLuiK+)r0GWr zxDoq(EtQA)Oh4m5oTkFtYL(^VfjHlfcAzUAn4v-3VU#JS5st14?TD%9ix^@vpTBQ_ z5`&Odef}={ZEUStO=0XLRz!Kf!Qs4av}A@;FUzA=TG}@j`AS5p^!JcW8Q!I5Ux4z^ z14HLPZ#vmqiFYwL-;2FuJ``9hLOt@U8E|;<)Vri}B<}}k5xa-2mF2TYYPG>I-4_ix zGr~ym>B0zCH8O9GM1AeSsM)TWaKXrK4KBC&$7&JtQsjasS7t9|K747jj1<0wA(8&R zi1QWLKxYZS&?IWi_3smc`DCe(VM0UaZ93Plo)L_58gLtr7xM@K(~5U^m2$({9GRzb*t&*3caozbp|1gjEtU@3_wig3;1=f(|0JW%e8ai!|y)h7!8?E z&7>X`4rVmiS`P$sGP-l6iO>tNzI*Yt=eOvy%pB}5oKHo2czyno7`^`$emz(xm-h3o z%5U>Zg&`DRlmw2E8wXJ;c$mTxrxVapQI7$%zq+bxPn|5UwQ46Srbt{o*RSlx$77`x zD*71NQneB!cT(oz#q~Y?`lgN@!ue#>YE16{61{hapK-m{{hhmYU$DcMi<8z*J&K(6 zTgdGJuZX6C75v_Ec%G6FAi{I`FxS{_l{-b4&P#uLHEuG?=UIJ>SJvi+G7e}Ae|2^_ zJ~fn$!ueqETh{CPDc_B*y7?~q$MFX+_)ZeqS;^fNxWfV&y=CTWA$~uVhgQ36Oi0v3 zRV8vs)vee$2;Jw4^0e`d=qXMB5X<2>_vE7QBHRJ1?>T78&#}L<#7l|u*`c1S%>V=V zo!Ew?29x%{14_+uq6!wHhWj*G$q;Q1ze)EVncrXUKw56zU=7e0Vu|hWc66EM*Bo%n zjpL%;-KN9#jGE&7f7|m&X%Uq^njV9zD(;2E ziE=W@N~JwIc^N|Z+eE)3ny9#CmaaJBTzJY1IbOBCe%0^vz6JYa6^q@(Vt29gN6udt z7Z0W@>wCG{j41;M!TiW-oG%NhuFwOSfPs0YY1-gkI9|f{TB4oCPN#Ct*2*W!h$wZE zmsq(k9swHdi7-bcS8oU5LTc65@4o9PF~?qb3typE2+;hAk@W@fQR@ET4JKAY|eF84j-Jh{j`jTr~4^`5IGbjvb;K_aZ z1(@r50=33_IUj@h5j}ka&Pa@Q(u|mJhVckNC5ltMpQgTvbCvxiNh1@o42$y_s3Cdv zph+vhq3-}YkBze=>f?vZ)224eyAem+FR+nUF=Cc7nnP`Qv<^93SO(C{J=oM%LF(xr z9iz8shr)UA_MWtmK*`>n+);ZN|C#=SPjrNG;no|Z*0~7IXN8m`RRSOFk%>krFRG_} zO{i{n>lZIgoOLqK+-mW|mpe_Nzj>Bt)hMb%iZ|5xaud0@%qM-@B*BrAM>+Vgope3) zRM;-IuU8nkcPab0h?``nPgIGqb~S2d4(AI%94oGUE98=|SxZ;dm$%`$@Y#1O0-tR* z&pYjVz5MVv7H8IMJ>sAfW+|wJOpC|+Dwm%Ql<(ZLYd@oHw7q<5OGn!W*(oCKuhNco zM25Z5d_OjIa$DlkYXs-tfedtA`zU1VrPAA$@A2VRhZ1hH+m5%ZuYPw!-<&1{D$rI| zzuI=-Xx%wnMB1Ac{1}gpEnb!rZ6hM;?q(+eMXvzrR!SkVFKs_`{95g_J$8G`QoZRl zRh@A>J! zgc%L%gc_IAYfM5{ZBmzc^2DAmM|VV}PHYFU>Hwaoq(T#y*GiMrnos7awC^2FV(ewe zfxi%_Y{e?qqC}^)D5)YNK5NTNCXIgUE^Na;>q1Wg=i5PYGb4a$b;Q+d&Vt@gZ}T1_ z(#m6~QEbWKre?>o@$3Al)dlbPuGC=7aO~oX1TRo=>S53=W{tjgJme)pt}tZ=};nX1*pK-T|T$2g3x!n3ccPU7~XXTK;jb%i_3 z)!+Zk!6VERJMG5#k0H6wi*f!fL*xw9jrs-aL{fa+=);~F)r&JAM5A+$x6xd?4<@}4 zpCS6?BdeO$mrvk0*x5rWg`u3FnxY?9oLA!QWmpfg9HTZyL3j;GZ9~}{Gc`D3_!Z8N zf+j*Mfo?0m*2mew4;!L-@@sgo*ehX(ppfq~E2E!GTK5T;Q!2zULI`CN{NDt0X+$hpg36Li#2|K@p9DE2Dgc>k0kE40 zZUs(AiDTc_)P@Z8fcIfo2WJ7gy(y`t+TqvV+~v45X0A_%a4hymM{`=$d(m5_- zx)wFEtDi#s@G7_#2bb^!gWjB=7$_? zwRK^xTl$BA_IhTRIg|*&aoSBu<&!^pw;fks2PxJ4`V)?QVNz3<-qju4ns@#yywl|% zXq~~yN77qY+I)Fw7_3=lfb$=6+!iPRrC9-z<54d6uB$8ad+;=V{@?`jcbs)#HR~bx zUWlAWW)L!FAFMVP2bxpP8M+G+>Uit#gWW6*s3I|ju*--=qnA{=zcD-|;gpIY9G#Rf zrd7;;vW4@d=ngU~K_h%1r<(IH>u0w|IrH z;6QzhT2zJ24=Ts;x~{5>-E)3)-VV&g6rA#wh!VOH#}U!SjPoIMRg)-SVh-`W4cdHQ zRd_f$@~fo1=*w81e31h*_1!a;9BzL7VnNO4EY+X>PP@cEB)=T_F!1O|+l|5aCWDal z`(p=r>86GWr$sFG>(_@Dm`96WEVnrH#YtOmJ}HC|T?X`LB7U`db}Wu7Sq}96IzQRH zz&dVZi&{{cxydeBFAUwygf`)cxgAmH@bK))tQhnw=E97U)%@2m{MU0zr+aOH&7$x~ z+V8OFkx}DlHPuV{Qs?lmIG-MBjmGVt2>ZKhu1*Eo7Ht*vJFkfCw&`qUw@-xHQhFq( zg&U8L`)IaxIdg`;KbL<$mg`P)J$){vhH|;p1!~B^Egq$1dV0^@Tc$%VkGIbP!V^08 z9;r+JGymBkpU}0y$0CGM<+OL9AzhvqhUZJlt*Je^>K`+8=5WpAb;GTiow@nv&$Q`< z>$dIL#5VdOs+dO!dxG!ee>`3YZThjCJCI*jN(|Z^9{6b!&_s4H_H9CLah3F_M?fmepX2o36Yi|MFpkyjf*Jmg(FwV9zqm}Q{3%xC{JHd4;69o}2yVL@K=E%#(EQ~iYd?L$J^WED$*9{rKe_QW5y zb3!E?o(J6AyqBZjlE^(5?U?y{+6hg9uYe1vxgNpjm zKRipd+9JCw6lYa1^GYMLGdruk5Q{Q2 z>y-I>b{4@mN`{6WLNPes3Y?Kw4D@FsS$TUTJ@Ua}txsbd%GW1`&_j$13dMKrdTU3G z8YHoFSg9G_(}kwW*=}+{jTF2fU17{w12T6E^wOECJ>u`7Z;qKEO8=#rzpoYafOmbR8~l5LnveN}YGdKm2N3C+INulIcwI28$@P>#gwQ*G}Q zM$&*-H*1?WXEQ$Q%rjXD-?igB8!J)_;6P@rqDjimVL}niOLGr|h15T>k;ZT;?wxmI^7JXLL8nVt0Q=?V2Q} zcqk@zZf+l4t+_BLR=+VmFqNRuB4eRoV5NC|6?c98*V!r5Y0C`KHV<1wWJ#>2Ga2Lj z049@aFW~E7yC*e6SMKmWg(M;?`kOqYVBQHNOTyAI`_0s$_UKU&P5ppa@06^M_&mKb zDRNqvBcW2Q$$$}y?R>H^CifkAQ{NdW82qt~jbccw^ZH7afW?Me=E9u#$yXheu93Hj5_Z`l}lY`@tTZ*ZcIv9|e3WM!tH|Nb7iJEQB?qF{#< zL24DoB;#%8>jzV^di~2XH%`K5Di=$yVm*t&OE}*E>PTb;%{Zmkxsd=u-drWOCAWx)#^!%u9Rbz<8joa5*N8Uuu}u z){zj?don;uif(S76zb&&V^>9XMf4-;CC8VvGQU0jJps=3PKbkyW>9{Pe>C77c#QJ{ zA#?VUKsN!OOu~L(iWf!jj{%|jL;0oZ!NJ1T{>csyUL>9IGJXun>d!$k;`$6QxTQHF zHX*7D#%KDpCQOUXswzE8?{6d9IC(M?EAqF+vHeGIRe>>0Ou}279}X@SX92oi={;|N2kd`ugwu|F1Is&A$J%|6dPkmL32A literal 0 HcmV?d00001 diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py new file mode 100644 index 00000000000..a63bd6373ea --- /dev/null +++ b/homeassistant/components/tts/demo.py @@ -0,0 +1,29 @@ +""" +Support for the demo speech service. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/demo/ +""" +import os + +from homeassistant.components.tts import Provider + + +def get_engine(hass, config): + """Setup Demo speech component.""" + return DemoProvider() + + +class DemoProvider(Provider): + """Demo speech api provider.""" + + def get_tts_audio(self, message): + """Load TTS from demo.""" + filename = os.path.join(os.path.dirname(__file__), "demo.mp3") + try: + with open(filename, 'rb') as voice: + data = voice.read() + except OSError: + return + + return ("mp3", data) diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py new file mode 100644 index 00000000000..b271b2468d1 --- /dev/null +++ b/homeassistant/components/tts/google.py @@ -0,0 +1,117 @@ +""" +Support for the google speech service. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/tts/google/ +""" +import asyncio +import logging +import re + +import aiohttp +import async_timeout +import yarl + +from homeassistant.components.tts import Provider +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +REQUIREMENTS = ["gTTS-token==1.1.1"] + +_LOGGER = logging.getLogger(__name__) + +GOOGLE_SPEECH_URL = "http://translate.google.com/translate_tts" +MESSAGE_SIZE = 148 + + +@asyncio.coroutine +def async_get_engine(hass, config): + """Setup Google speech component.""" + return GoogleProvider(hass) + + +class GoogleProvider(Provider): + """Google speech api provider.""" + + def __init__(self, hass): + """Init Google TTS service.""" + self.hass = hass + self.headers = { + 'Referer': "http://translate.google.com/", + 'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/47.0.2526.106 Safari/537.36") + } + + @asyncio.coroutine + def async_get_tts_audio(self, message): + """Load TTS from google.""" + from gtts_token import gtts_token + + token = gtts_token.Token() + websession = async_get_clientsession(self.hass) + message_parts = self._split_message_to_parts(message) + + data = b'' + for idx, part in enumerate(message_parts): + part_token = yield from self.hass.loop.run_in_executor( + None, token.calculate_token, part) + + url_param = { + 'ie': 'UTF-8', + 'tl': self.language, + 'q': yarl.quote(part), + 'tk': part_token, + 'total': len(message_parts), + 'idx': idx, + 'client': 'tw-ob', + 'textlen': len(part), + } + + request = None + try: + with async_timeout.timeout(10, loop=self.hass.loop): + request = yield from websession.get( + GOOGLE_SPEECH_URL, params=url_param, + headers=self.headers + ) + + if request.status != 200: + _LOGGER.error("Error %d on load url %s", request.code, + request.url) + return (None, None) + data += yield from request.read() + + except (asyncio.TimeoutError, aiohttp.errors.ClientError): + _LOGGER.error("Timeout for google speech.") + return (None, None) + + finally: + if request is not None: + yield from request.release() + + return ("mp3", data) + + @staticmethod + def _split_message_to_parts(message): + """Split message into single parts.""" + if len(message) <= MESSAGE_SIZE: + return [message] + + punc = "!()[]?.,;:" + punc_list = [re.escape(c) for c in punc] + pattern = '|'.join(punc_list) + parts = re.split(pattern, message) + + def split_by_space(fullstring): + """Split a string by space.""" + if len(fullstring) > MESSAGE_SIZE: + idx = fullstring.rfind(' ', 0, MESSAGE_SIZE) + return [fullstring[:idx]] + split_by_space(fullstring[idx:]) + else: + return [fullstring] + + msg_parts = [] + for part in parts: + msg_parts += split_by_space(part) + + return [msg for msg in msg_parts if len(msg) > 0] diff --git a/homeassistant/components/tts/services.yaml b/homeassistant/components/tts/services.yaml new file mode 100644 index 00000000000..aba1334da87 --- /dev/null +++ b/homeassistant/components/tts/services.yaml @@ -0,0 +1,14 @@ +say: + description: Say some things on a media player. + + fields: + entity_id: + description: Name(s) of media player entities + example: 'media_player.floor' + + message: + description: Text to speak on devices + example: 'My name is hanna' + +clear_cache: + description: Remove cache files and RAM cache. diff --git a/requirements_all.txt b/requirements_all.txt index 799b2b29b11..c85e3285e7e 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -130,6 +130,9 @@ freesms==0.1.1 # homeassistant.components.conversation fuzzywuzzy==0.14.0 +# homeassistant.components.tts.google +gTTS-token==1.1.1 + # homeassistant.components.device_tracker.bluetooth_le_tracker # gattlib==0.20150805 diff --git a/tests/components/tts/__init__.py b/tests/components/tts/__init__.py new file mode 100644 index 00000000000..f5eb0731409 --- /dev/null +++ b/tests/components/tts/__init__.py @@ -0,0 +1 @@ +"""The tests for tts platforms.""" diff --git a/tests/components/tts/test_google.py b/tests/components/tts/test_google.py new file mode 100644 index 00000000000..623a96f1dfb --- /dev/null +++ b/tests/components/tts/test_google.py @@ -0,0 +1,199 @@ +"""The tests for the Google speech platform.""" +import asyncio +import os +import shutil +from unittest.mock import patch + +import homeassistant.components.tts as tts +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP) +from homeassistant.bootstrap import setup_component + +from tests.common import ( + get_test_home_assistant, assert_setup_component, mock_service) + + +class TestTTSGooglePlatform(object): + """Test the Google speech component.""" + + def setup_method(self): + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + + self.url = "http://translate.google.com/translate_tts" + self.url_param = { + 'tl': 'en', + 'q': 'I%20person%20is%20on%20front%20of%20your%20door.', + 'tk': 5, + 'client': 'tw-ob', + 'textlen': 34, + 'total': 1, + 'idx': 0, + 'ie': 'UTF-8', + } + + def teardown_method(self): + """Stop everything that was started.""" + default_tts = self.hass.config.path(tts.DEFAULT_CACHE_DIR) + if os.path.isdir(default_tts): + shutil.rmtree(default_tts) + + self.hass.stop() + + def test_setup_component(self): + """Test setup component.""" + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say(self, mock_calculate, aioclient_mock): + """Test service call say.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_german(self, mock_calculate, aioclient_mock): + """Test service call say with german code.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.url_param['tl'] = 'de' + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + 'language': 'de', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_error(self, mock_calculate, aioclient_mock): + """Test service call say with http response 400.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, status=400, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_timeout(self, mock_calculate, aioclient_mock): + """Test service call say with http timeout.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, exc=asyncio.TimeoutError()) + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_long_size(self, mock_calculate, aioclient_mock): + """Test service call say with a lot of text.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.url_param['total'] = 9 + self.url_param['q'] = "I%20person%20is%20on%20front%20of%20your%20door" + self.url_param['textlen'] = 33 + for idx in range(0, 9): + self.url_param['idx'] = idx + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: ("I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door."), + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 9 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1 diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py new file mode 100644 index 00000000000..fbdbddb8db5 --- /dev/null +++ b/tests/components/tts/test_init.py @@ -0,0 +1,320 @@ +"""The tests for the TTS component.""" +import os +import shutil +from unittest.mock import patch + +import requests + +import homeassistant.components.tts as tts +from homeassistant.components.tts.demo import DemoProvider +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP) +from homeassistant.bootstrap import setup_component + +from tests.common import ( + get_test_home_assistant, assert_setup_component, mock_service) + + +class TestTTS(object): + """Test the Google speech component.""" + + def setup_method(self): + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + self.demo_provider = DemoProvider() + self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR) + + def teardown_method(self): + """Stop everything that was started.""" + if os.path.isdir(self.default_tts_cache): + shutil.rmtree(self.default_tts_cache) + + self.hass.stop() + + def test_setup_component_demo(self): + """Setup the demo platform with defaults.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + assert self.hass.services.has_service(tts.DOMAIN, 'demo_say') + assert self.hass.services.has_service(tts.DOMAIN, 'clear_cache') + + @patch('os.mkdir', side_effect=OSError(2, "No access")) + def test_setup_component_demo_no_access_cache_folder(self, mock_mkdir): + """Setup the demo platform with defaults.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + assert not setup_component(self.hass, tts.DOMAIN, config) + + assert not self.hass.services.has_service(tts.DOMAIN, 'demo_say') + assert not self.hass.services.has_service(tts.DOMAIN, 'clear_cache') + + def test_setup_component_and_test_service(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3") \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_and_test_service_clear_cache(self): + """Setup the demo platform and call service clear cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + self.hass.services.call(tts.DOMAIN, tts.SERVICE_CLEAR_CACHE, {}) + self.hass.block_till_done() + + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_and_test_service_with_receive_voice(self): + """Setup the demo platform and call service and receive voice.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) + _, demo_data = self.demo_provider.get_tts_audio("bla") + assert req.status_code == 200 + assert req.content == demo_data + + def test_setup_component_and_web_view_wrong_file(self): + """Setup the demo platform and receive wrong file from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 404 + + def test_setup_component_and_web_view_wrong_filename(self): + """Setup the demo platform and receive wrong filename from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 404 + + def test_setup_component_test_without_cache(self): + """Setup demo platform without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': False, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_test_with_cache_call_service_without_cache(self): + """Setup demo platform with cache and call service without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_CACHE: False, + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_test_with_cache_dir(self): + """Setup demo platform with cache and call service without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + _, demo_data = self.demo_provider.get_tts_audio("bla") + cache_file = os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + + os.mkdir(self.default_tts_cache) + with open(cache_file, "wb") as voice_file: + voice_file.write(demo_data) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + with patch('homeassistant.components.tts.demo.DemoProvider.' + 'get_tts_audio', return_value=None): + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3") \ + != -1 + + @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', + return_value=None) + def test_setup_component_test_with_error_on_get_tts(self, tts_mock): + """Setup demo platform with wrong get_tts_audio.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo' + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + + def test_setup_component_load_cache_retrieve_without_mem_cache(self): + """Setup component and load cache and get without mem cache.""" + _, demo_data = self.demo_provider.get_tts_audio("bla") + cache_file = os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + + os.mkdir(self.default_tts_cache) + with open(cache_file, "wb") as voice_file: + voice_file.write(demo_data) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 200 + assert req.content == demo_data diff --git a/tests/test_util/aiohttp.py b/tests/test_util/aiohttp.py index d6f0c80b435..4abf43a6e42 100644 --- a/tests/test_util/aiohttp.py +++ b/tests/test_util/aiohttp.py @@ -23,6 +23,7 @@ class AiohttpClientMocker: content=None, json=None, params=None, + headers=None, exc=None): """Mock a request.""" if json: @@ -65,8 +66,8 @@ class AiohttpClientMocker: return len(self.mock_calls) @asyncio.coroutine - def match_request(self, method, url, *, auth=None, params=None): \ - # pylint: disable=unused-variable + def match_request(self, method, url, *, auth=None, params=None, + headers=None): # pylint: disable=unused-variable """Match a request against pre-registered requests.""" for response in self._mocks: if response.match_request(method, url, params): @@ -76,8 +77,8 @@ class AiohttpClientMocker: raise self.exc return response - assert False, "No mock registered for {} {}".format(method.upper(), - url) + assert False, "No mock registered for {} {} {}".format(method.upper(), + url, params) class AiohttpClientMockResponse: