Add type hints to tts (#64050)

* Add type hint for _get_cache_files

* Add type hint for _init_tts_cache_dir

* Add init type hints for async_clear_cache

* Add type hints to async_setup_platform

* Add type hints to async_register_engine

* Add type hints to self.providers

* Add type hints to _async_store_to_memcache

* Add type hints to async_file_to_mem

* Add full type hints

* Use tuple in async_read_tts

Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
epenet 2022-01-14 12:35:29 +01:00 committed by GitHub
parent 44a686931e
commit c1692a324b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,7 @@ import logging
import mimetypes
import os
import re
from typing import Optional, cast
from typing import TYPE_CHECKING, Optional, cast
from aiohttp import web
import mutagen
@ -38,6 +38,7 @@ from homeassistant.helpers import config_per_platform, discovery
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.yaml import load_yaml
@ -117,7 +118,7 @@ SCHEMA_SERVICE_SAY = vol.Schema(
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
async def async_setup(hass, config):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
tts = SpeechManager(hass)
@ -144,7 +145,11 @@ async def async_setup(hass, config):
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
)
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
async def async_setup_platform(
p_type: str,
p_config: ConfigType | None = None,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up a TTS platform."""
if p_config is None:
p_config = {}
@ -175,7 +180,7 @@ async def async_setup(hass, config):
async def async_say_handle(service: ServiceCall) -> None:
"""Service handle for say."""
entity_ids = service.data[ATTR_ENTITY_ID]
message = service.data.get(ATTR_MESSAGE)
message = service.data[ATTR_MESSAGE]
cache = service.data.get(ATTR_CACHE)
language = service.data.get(ATTR_LANGUAGE)
options = service.data.get(ATTR_OPTIONS)
@ -221,6 +226,7 @@ async def async_setup(hass, config):
setup_tasks = [
asyncio.create_task(async_setup_platform(p_type, p_config))
for p_type, p_config in config_per_platform(config, DOMAIN)
if p_type is not None
]
if setup_tasks:
@ -259,19 +265,21 @@ def _hash_options(options: dict) -> str:
class SpeechManager:
"""Representation of a speech store."""
def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a speech store."""
self.hass = hass
self.providers = {}
self.providers: dict[str, Provider] = {}
self.use_cache = DEFAULT_CACHE
self.cache_dir = DEFAULT_CACHE_DIR
self.time_memory = DEFAULT_TIME_MEMORY
self.base_url = None
self.file_cache = {}
self.mem_cache = {}
self.base_url: str | None = None
self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, dict[str, str | bytes]] = {}
async def async_init_cache(self, use_cache, cache_dir, time_memory, base_url):
async def async_init_cache(
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
) -> None:
"""Init config folder and load file cache."""
self.use_cache = use_cache
self.time_memory = time_memory
@ -294,7 +302,7 @@ class SpeechManager:
if cache_files:
self.file_cache.update(cache_files)
async def async_clear_cache(self):
async def async_clear_cache(self) -> None:
"""Read file cache and delete files."""
self.mem_cache = {}
@ -310,7 +318,9 @@ class SpeechManager:
self.file_cache = {}
@callback
def async_register_engine(self, engine, provider, config):
def async_register_engine(
self, engine: str, provider: Provider, config: ConfigType
) -> None:
"""Register a TTS provider."""
provider.hass = self.hass
if provider.name is None:
@ -322,8 +332,13 @@ class SpeechManager:
)
async def async_get_url_path(
self, engine, message, cache=None, language=None, options=None
):
self,
engine: str,
message: str,
cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> str:
"""Get URL for play message.
This method is a coroutine.
@ -362,7 +377,7 @@ class SpeechManager:
# Is speech already in memory
if key in self.mem_cache:
filename = self.mem_cache[key][MEM_CACHE_FILENAME]
filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
# Is file store in file cache
elif use_cache and key in self.file_cache:
filename = self.file_cache[key]
@ -375,7 +390,15 @@ class SpeechManager:
return f"/api/tts_proxy/{filename}"
async def async_get_tts_audio(self, engine, key, message, cache, language, options):
async def async_get_tts_audio(
self,
engine: str,
key: str,
message: str,
cache: bool,
language: str,
options: dict | None,
) -> str:
"""Receive TTS and store for view in cache.
This method is a coroutine.
@ -404,14 +427,14 @@ class SpeechManager:
return filename
async def async_save_tts_audio(self, key, filename, data):
async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
"""Store voice data to file and file_cache.
This method is a coroutine.
"""
voice_file = os.path.join(self.cache_dir, filename)
def save_speech():
def save_speech() -> None:
"""Store speech to filesystem."""
with open(voice_file, "wb") as speech:
speech.write(data)
@ -422,7 +445,7 @@ class SpeechManager:
except OSError as err:
_LOGGER.error("Can't write %s: %s", filename, err)
async def async_file_to_mem(self, key):
async def async_file_to_mem(self, key: str) -> None:
"""Load voice from file cache into memory.
This method is a coroutine.
@ -432,7 +455,7 @@ class SpeechManager:
voice_file = os.path.join(self.cache_dir, filename)
def load_speech():
def load_speech() -> bytes:
"""Load a speech from filesystem."""
with open(voice_file, "rb") as speech:
return speech.read()
@ -446,18 +469,18 @@ class SpeechManager:
self._async_store_to_memcache(key, filename, data)
@callback
def _async_store_to_memcache(self, key, filename, data):
def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
"""Store data to memcache and set timer to remove it."""
self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}
@callback
def async_remove_from_mem():
def async_remove_from_mem() -> None:
"""Cleanup memcache."""
self.mem_cache.pop(key, None)
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
async def async_read_tts(self, filename):
async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
"""Read a voice file and return binary.
This method is a coroutine.
@ -475,10 +498,17 @@ class SpeechManager:
await self.async_file_to_mem(key)
content, _ = mimetypes.guess_type(filename)
return content, self.mem_cache[key][MEM_CACHE_VOICE]
return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])
@staticmethod
def write_tags(filename, data, provider, message, language, options):
def write_tags(
filename: str,
data: bytes,
provider: Provider,
message: str,
language: str,
options: dict | None,
) -> bytes:
"""Write ID3 tags to file.
Async friendly.
@ -491,8 +521,8 @@ class SpeechManager:
album = provider.name
artist = language
if options is not None and options.get("voice") is not None:
artist = options.get("voice")
if options is not None and (voice := options.get("voice")) is not None:
artist = voice
try:
tts_file = mutagen.File(data_bytes)
@ -540,21 +570,27 @@ class Provider:
"""Return a dict include default options."""
return None
def get_tts_audio(self, message, language, options=None):
def get_tts_audio(
self, message: str, language: str, options: dict | None = None
) -> TtsAudioType:
"""Load tts audio file from provider."""
raise NotImplementedError()
async def async_get_tts_audio(self, message, language, options=None):
async def async_get_tts_audio(
self, message: str, language: str, options: dict | None = None
) -> TtsAudioType:
"""Load tts audio file from provider.
Return a tuple of file extension and data as bytes.
"""
if TYPE_CHECKING:
assert self.hass
return await self.hass.async_add_executor_job(
ft.partial(self.get_tts_audio, message, language, options=options)
)
def _init_tts_cache_dir(hass, cache_dir):
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
"""Init cache folder."""
if not os.path.isabs(cache_dir):
cache_dir = hass.config.path(cache_dir)
@ -564,7 +600,7 @@ def _init_tts_cache_dir(hass, cache_dir):
return cache_dir
def _get_cache_files(cache_dir):
def _get_cache_files(cache_dir: str) -> dict[str, str]:
"""Return a dict of given engine files."""
cache = {}
@ -585,7 +621,7 @@ class TextToSpeechUrlView(HomeAssistantView):
url = "/api/tts_get_url"
name = "api:tts:geturl"
def __init__(self, tts):
def __init__(self, tts: SpeechManager) -> None:
"""Initialize a tts view."""
self.tts = tts
@ -627,7 +663,7 @@ class TextToSpeechView(HomeAssistantView):
url = "/api/tts_proxy/{filename}"
name = "api:tts_speech"
def __init__(self, tts):
def __init__(self, tts: SpeechManager) -> None:
"""Initialize a tts view."""
self.tts = tts