Add type hints to tts (#64050)

* Add type hint for _get_cache_files

* Add type hint for _init_tts_cache_dir

* Add init type hints for async_clear_cache

* Add type hints to async_setup_platform

* Add type hints to async_register_engine

* Add type hints to self.providers

* Add type hints to _async_store_to_memcache

* Add type hints to async_file_to_mem

* Add full type hints

* Use tuple in async_read_tts

Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
epenet 2022-01-14 12:35:29 +01:00 committed by GitHub
parent 44a686931e
commit c1692a324b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
from typing import Optional, cast from typing import TYPE_CHECKING, Optional, cast
from aiohttp import web from aiohttp import web
import mutagen import mutagen
@ -38,6 +38,7 @@ from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import get_url from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
@ -117,7 +118,7 @@ SCHEMA_SERVICE_SAY = vol.Schema(
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
async def async_setup(hass, config): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS.""" """Set up TTS."""
tts = SpeechManager(hass) tts = SpeechManager(hass)
@ -144,7 +145,11 @@ async def async_setup(hass, config):
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml)) dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
) )
async def async_setup_platform(p_type, p_config=None, discovery_info=None): async def async_setup_platform(
p_type: str,
p_config: ConfigType | None = None,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up a TTS platform.""" """Set up a TTS platform."""
if p_config is None: if p_config is None:
p_config = {} p_config = {}
@ -175,7 +180,7 @@ async def async_setup(hass, config):
async def async_say_handle(service: ServiceCall) -> None: async def async_say_handle(service: ServiceCall) -> None:
"""Service handle for say.""" """Service handle for say."""
entity_ids = service.data[ATTR_ENTITY_ID] entity_ids = service.data[ATTR_ENTITY_ID]
message = service.data.get(ATTR_MESSAGE) message = service.data[ATTR_MESSAGE]
cache = service.data.get(ATTR_CACHE) cache = service.data.get(ATTR_CACHE)
language = service.data.get(ATTR_LANGUAGE) language = service.data.get(ATTR_LANGUAGE)
options = service.data.get(ATTR_OPTIONS) options = service.data.get(ATTR_OPTIONS)
@ -221,6 +226,7 @@ async def async_setup(hass, config):
setup_tasks = [ setup_tasks = [
asyncio.create_task(async_setup_platform(p_type, p_config)) asyncio.create_task(async_setup_platform(p_type, p_config))
for p_type, p_config in config_per_platform(config, DOMAIN) for p_type, p_config in config_per_platform(config, DOMAIN)
if p_type is not None
] ]
if setup_tasks: if setup_tasks:
@ -259,19 +265,21 @@ def _hash_options(options: dict) -> str:
class SpeechManager: class SpeechManager:
"""Representation of a speech store.""" """Representation of a speech store."""
def __init__(self, hass): def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a speech store.""" """Initialize a speech store."""
self.hass = hass self.hass = hass
self.providers = {} self.providers: dict[str, Provider] = {}
self.use_cache = DEFAULT_CACHE self.use_cache = DEFAULT_CACHE
self.cache_dir = DEFAULT_CACHE_DIR self.cache_dir = DEFAULT_CACHE_DIR
self.time_memory = DEFAULT_TIME_MEMORY self.time_memory = DEFAULT_TIME_MEMORY
self.base_url = None self.base_url: str | None = None
self.file_cache = {} self.file_cache: dict[str, str] = {}
self.mem_cache = {} self.mem_cache: dict[str, dict[str, str | bytes]] = {}
async def async_init_cache(self, use_cache, cache_dir, time_memory, base_url): async def async_init_cache(
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
) -> None:
"""Init config folder and load file cache.""" """Init config folder and load file cache."""
self.use_cache = use_cache self.use_cache = use_cache
self.time_memory = time_memory self.time_memory = time_memory
@ -294,7 +302,7 @@ class SpeechManager:
if cache_files: if cache_files:
self.file_cache.update(cache_files) self.file_cache.update(cache_files)
async def async_clear_cache(self): async def async_clear_cache(self) -> None:
"""Read file cache and delete files.""" """Read file cache and delete files."""
self.mem_cache = {} self.mem_cache = {}
@ -310,7 +318,9 @@ class SpeechManager:
self.file_cache = {} self.file_cache = {}
@callback @callback
def async_register_engine(self, engine, provider, config): def async_register_engine(
self, engine: str, provider: Provider, config: ConfigType
) -> None:
"""Register a TTS provider.""" """Register a TTS provider."""
provider.hass = self.hass provider.hass = self.hass
if provider.name is None: if provider.name is None:
@ -322,8 +332,13 @@ class SpeechManager:
) )
async def async_get_url_path( async def async_get_url_path(
self, engine, message, cache=None, language=None, options=None self,
): engine: str,
message: str,
cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> str:
"""Get URL for play message. """Get URL for play message.
This method is a coroutine. This method is a coroutine.
@ -362,7 +377,7 @@ class SpeechManager:
# Is speech already in memory # Is speech already in memory
if key in self.mem_cache: if key in self.mem_cache:
filename = self.mem_cache[key][MEM_CACHE_FILENAME] filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
# Is file store in file cache # Is file store in file cache
elif use_cache and key in self.file_cache: elif use_cache and key in self.file_cache:
filename = self.file_cache[key] filename = self.file_cache[key]
@ -375,7 +390,15 @@ class SpeechManager:
return f"/api/tts_proxy/{filename}" return f"/api/tts_proxy/{filename}"
async def async_get_tts_audio(self, engine, key, message, cache, language, options): async def async_get_tts_audio(
self,
engine: str,
key: str,
message: str,
cache: bool,
language: str,
options: dict | None,
) -> str:
"""Receive TTS and store for view in cache. """Receive TTS and store for view in cache.
This method is a coroutine. This method is a coroutine.
@ -404,14 +427,14 @@ class SpeechManager:
return filename return filename
async def async_save_tts_audio(self, key, filename, data): async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
"""Store voice data to file and file_cache. """Store voice data to file and file_cache.
This method is a coroutine. This method is a coroutine.
""" """
voice_file = os.path.join(self.cache_dir, filename) voice_file = os.path.join(self.cache_dir, filename)
def save_speech(): def save_speech() -> None:
"""Store speech to filesystem.""" """Store speech to filesystem."""
with open(voice_file, "wb") as speech: with open(voice_file, "wb") as speech:
speech.write(data) speech.write(data)
@ -422,7 +445,7 @@ class SpeechManager:
except OSError as err: except OSError as err:
_LOGGER.error("Can't write %s: %s", filename, err) _LOGGER.error("Can't write %s: %s", filename, err)
async def async_file_to_mem(self, key): async def async_file_to_mem(self, key: str) -> None:
"""Load voice from file cache into memory. """Load voice from file cache into memory.
This method is a coroutine. This method is a coroutine.
@ -432,7 +455,7 @@ class SpeechManager:
voice_file = os.path.join(self.cache_dir, filename) voice_file = os.path.join(self.cache_dir, filename)
def load_speech(): def load_speech() -> bytes:
"""Load a speech from filesystem.""" """Load a speech from filesystem."""
with open(voice_file, "rb") as speech: with open(voice_file, "rb") as speech:
return speech.read() return speech.read()
@ -446,18 +469,18 @@ class SpeechManager:
self._async_store_to_memcache(key, filename, data) self._async_store_to_memcache(key, filename, data)
@callback @callback
def _async_store_to_memcache(self, key, filename, data): def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
"""Store data to memcache and set timer to remove it.""" """Store data to memcache and set timer to remove it."""
self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data} self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}
@callback @callback
def async_remove_from_mem(): def async_remove_from_mem() -> None:
"""Cleanup memcache.""" """Cleanup memcache."""
self.mem_cache.pop(key, None) self.mem_cache.pop(key, None)
self.hass.loop.call_later(self.time_memory, async_remove_from_mem) self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
async def async_read_tts(self, filename): async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
"""Read a voice file and return binary. """Read a voice file and return binary.
This method is a coroutine. This method is a coroutine.
@ -475,10 +498,17 @@ class SpeechManager:
await self.async_file_to_mem(key) await self.async_file_to_mem(key)
content, _ = mimetypes.guess_type(filename) content, _ = mimetypes.guess_type(filename)
return content, self.mem_cache[key][MEM_CACHE_VOICE] return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])
@staticmethod @staticmethod
def write_tags(filename, data, provider, message, language, options): def write_tags(
filename: str,
data: bytes,
provider: Provider,
message: str,
language: str,
options: dict | None,
) -> bytes:
"""Write ID3 tags to file. """Write ID3 tags to file.
Async friendly. Async friendly.
@ -491,8 +521,8 @@ class SpeechManager:
album = provider.name album = provider.name
artist = language artist = language
if options is not None and options.get("voice") is not None: if options is not None and (voice := options.get("voice")) is not None:
artist = options.get("voice") artist = voice
try: try:
tts_file = mutagen.File(data_bytes) tts_file = mutagen.File(data_bytes)
@ -540,21 +570,27 @@ class Provider:
"""Return a dict include default options.""" """Return a dict include default options."""
return None return None
def get_tts_audio(self, message, language, options=None): def get_tts_audio(
self, message: str, language: str, options: dict | None = None
) -> TtsAudioType:
"""Load tts audio file from provider.""" """Load tts audio file from provider."""
raise NotImplementedError() raise NotImplementedError()
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(
self, message: str, language: str, options: dict | None = None
) -> TtsAudioType:
"""Load tts audio file from provider. """Load tts audio file from provider.
Return a tuple of file extension and data as bytes. Return a tuple of file extension and data as bytes.
""" """
if TYPE_CHECKING:
assert self.hass
return await self.hass.async_add_executor_job( return await self.hass.async_add_executor_job(
ft.partial(self.get_tts_audio, message, language, options=options) ft.partial(self.get_tts_audio, message, language, options=options)
) )
def _init_tts_cache_dir(hass, cache_dir): def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
"""Init cache folder.""" """Init cache folder."""
if not os.path.isabs(cache_dir): if not os.path.isabs(cache_dir):
cache_dir = hass.config.path(cache_dir) cache_dir = hass.config.path(cache_dir)
@ -564,7 +600,7 @@ def _init_tts_cache_dir(hass, cache_dir):
return cache_dir return cache_dir
def _get_cache_files(cache_dir): def _get_cache_files(cache_dir: str) -> dict[str, str]:
"""Return a dict of given engine files.""" """Return a dict of given engine files."""
cache = {} cache = {}
@ -585,7 +621,7 @@ class TextToSpeechUrlView(HomeAssistantView):
url = "/api/tts_get_url" url = "/api/tts_get_url"
name = "api:tts:geturl" name = "api:tts:geturl"
def __init__(self, tts): def __init__(self, tts: SpeechManager) -> None:
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.tts = tts
@ -627,7 +663,7 @@ class TextToSpeechView(HomeAssistantView):
url = "/api/tts_proxy/{filename}" url = "/api/tts_proxy/{filename}"
name = "api:tts_speech" name = "api:tts_speech"
def __init__(self, tts): def __init__(self, tts: SpeechManager) -> None:
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.tts = tts