From c1692a324b64881ca9a31afc3f6f38f6c26d8b88 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Fri, 14 Jan 2022 12:35:29 +0100 Subject: [PATCH] Add type hints to tts (#64050) * Add type hint for _get_cache_files * Add type hint for _init_tts_cache_dir * Add init type hints for async_clear_cache * Add type hints to async_setup_platform * Add type hints to async_register_engine * Add type hints to self.providers * Add type hints to _async_store_to_memcache * Add type hints to async_file_to_mem * Add full type hints * Use tuple in async_read_tts Co-authored-by: epenet --- homeassistant/components/tts/__init__.py | 102 +++++++++++++++-------- 1 file changed, 69 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index a7ab7139aa0..807390f8992 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -10,7 +10,7 @@ import logging import mimetypes import os import re -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from aiohttp import web import mutagen @@ -38,6 +38,7 @@ from homeassistant.helpers import config_per_platform, discovery import homeassistant.helpers.config_validation as cv from homeassistant.helpers.network import get_url from homeassistant.helpers.service import async_set_service_schema +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.loader import async_get_integration from homeassistant.setup import async_prepare_setup_platform from homeassistant.util.yaml import load_yaml @@ -117,7 +118,7 @@ SCHEMA_SERVICE_SAY = vol.Schema( SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) -async def async_setup(hass, config): +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up TTS.""" tts = SpeechManager(hass) @@ -144,7 +145,11 @@ async def async_setup(hass, config): dict, await hass.async_add_executor_job(load_yaml, str(services_yaml)) ) - async def async_setup_platform(p_type, p_config=None, discovery_info=None): + async def async_setup_platform( + p_type: str, + p_config: ConfigType | None = None, + discovery_info: DiscoveryInfoType | None = None, + ) -> None: """Set up a TTS platform.""" if p_config is None: p_config = {} @@ -175,7 +180,7 @@ async def async_setup(hass, config): async def async_say_handle(service: ServiceCall) -> None: """Service handle for say.""" entity_ids = service.data[ATTR_ENTITY_ID] - message = service.data.get(ATTR_MESSAGE) + message = service.data[ATTR_MESSAGE] cache = service.data.get(ATTR_CACHE) language = service.data.get(ATTR_LANGUAGE) options = service.data.get(ATTR_OPTIONS) @@ -221,6 +226,7 @@ async def async_setup(hass, config): setup_tasks = [ asyncio.create_task(async_setup_platform(p_type, p_config)) for p_type, p_config in config_per_platform(config, DOMAIN) + if p_type is not None ] if setup_tasks: @@ -259,19 +265,21 @@ def _hash_options(options: dict) -> str: class SpeechManager: """Representation of a speech store.""" - def __init__(self, hass): + def __init__(self, hass: HomeAssistant) -> None: """Initialize a speech store.""" self.hass = hass - self.providers = {} + self.providers: dict[str, Provider] = {} self.use_cache = DEFAULT_CACHE self.cache_dir = DEFAULT_CACHE_DIR self.time_memory = DEFAULT_TIME_MEMORY - self.base_url = None - self.file_cache = {} - self.mem_cache = {} + self.base_url: str | None = None + self.file_cache: dict[str, str] = {} + self.mem_cache: dict[str, dict[str, str | bytes]] = {} - async def async_init_cache(self, use_cache, cache_dir, time_memory, base_url): + async def async_init_cache( + self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None + ) -> None: """Init config folder and load file cache.""" self.use_cache = use_cache self.time_memory = time_memory @@ -294,7 +302,7 @@ class SpeechManager: if cache_files: self.file_cache.update(cache_files) - async def async_clear_cache(self): + async def async_clear_cache(self) -> None: """Read file cache and delete files.""" self.mem_cache = {} @@ -310,7 +318,9 @@ class SpeechManager: self.file_cache = {} @callback - def async_register_engine(self, engine, provider, config): + def async_register_engine( + self, engine: str, provider: Provider, config: ConfigType + ) -> None: """Register a TTS provider.""" provider.hass = self.hass if provider.name is None: @@ -322,8 +332,13 @@ class SpeechManager: ) async def async_get_url_path( - self, engine, message, cache=None, language=None, options=None - ): + self, + engine: str, + message: str, + cache: bool | None = None, + language: str | None = None, + options: dict | None = None, + ) -> str: """Get URL for play message. This method is a coroutine. @@ -362,7 +377,7 @@ class SpeechManager: # Is speech already in memory if key in self.mem_cache: - filename = self.mem_cache[key][MEM_CACHE_FILENAME] + filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME]) # Is file store in file cache elif use_cache and key in self.file_cache: filename = self.file_cache[key] @@ -375,7 +390,15 @@ class SpeechManager: return f"/api/tts_proxy/{filename}" - async def async_get_tts_audio(self, engine, key, message, cache, language, options): + async def async_get_tts_audio( + self, + engine: str, + key: str, + message: str, + cache: bool, + language: str, + options: dict | None, + ) -> str: """Receive TTS and store for view in cache. This method is a coroutine. @@ -404,14 +427,14 @@ class SpeechManager: return filename - async def async_save_tts_audio(self, key, filename, data): + async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None: """Store voice data to file and file_cache. This method is a coroutine. """ voice_file = os.path.join(self.cache_dir, filename) - def save_speech(): + def save_speech() -> None: """Store speech to filesystem.""" with open(voice_file, "wb") as speech: speech.write(data) @@ -422,7 +445,7 @@ class SpeechManager: except OSError as err: _LOGGER.error("Can't write %s: %s", filename, err) - async def async_file_to_mem(self, key): + async def async_file_to_mem(self, key: str) -> None: """Load voice from file cache into memory. This method is a coroutine. @@ -432,7 +455,7 @@ class SpeechManager: voice_file = os.path.join(self.cache_dir, filename) - def load_speech(): + def load_speech() -> bytes: """Load a speech from filesystem.""" with open(voice_file, "rb") as speech: return speech.read() @@ -446,18 +469,18 @@ class SpeechManager: self._async_store_to_memcache(key, filename, data) @callback - def _async_store_to_memcache(self, key, filename, data): + def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None: """Store data to memcache and set timer to remove it.""" self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data} @callback - def async_remove_from_mem(): + def async_remove_from_mem() -> None: """Cleanup memcache.""" self.mem_cache.pop(key, None) self.hass.loop.call_later(self.time_memory, async_remove_from_mem) - async def async_read_tts(self, filename): + async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]: """Read a voice file and return binary. This method is a coroutine. @@ -475,10 +498,17 @@ class SpeechManager: await self.async_file_to_mem(key) content, _ = mimetypes.guess_type(filename) - return content, self.mem_cache[key][MEM_CACHE_VOICE] + return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE]) @staticmethod - def write_tags(filename, data, provider, message, language, options): + def write_tags( + filename: str, + data: bytes, + provider: Provider, + message: str, + language: str, + options: dict | None, + ) -> bytes: """Write ID3 tags to file. Async friendly. @@ -491,8 +521,8 @@ class SpeechManager: album = provider.name artist = language - if options is not None and options.get("voice") is not None: - artist = options.get("voice") + if options is not None and (voice := options.get("voice")) is not None: + artist = voice try: tts_file = mutagen.File(data_bytes) @@ -540,21 +570,27 @@ class Provider: """Return a dict include default options.""" return None - def get_tts_audio(self, message, language, options=None): + def get_tts_audio( + self, message: str, language: str, options: dict | None = None + ) -> TtsAudioType: """Load tts audio file from provider.""" raise NotImplementedError() - async def async_get_tts_audio(self, message, language, options=None): + async def async_get_tts_audio( + self, message: str, language: str, options: dict | None = None + ) -> TtsAudioType: """Load tts audio file from provider. Return a tuple of file extension and data as bytes. """ + if TYPE_CHECKING: + assert self.hass return await self.hass.async_add_executor_job( ft.partial(self.get_tts_audio, message, language, options=options) ) -def _init_tts_cache_dir(hass, cache_dir): +def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str: """Init cache folder.""" if not os.path.isabs(cache_dir): cache_dir = hass.config.path(cache_dir) @@ -564,7 +600,7 @@ def _init_tts_cache_dir(hass, cache_dir): return cache_dir -def _get_cache_files(cache_dir): +def _get_cache_files(cache_dir: str) -> dict[str, str]: """Return a dict of given engine files.""" cache = {} @@ -585,7 +621,7 @@ class TextToSpeechUrlView(HomeAssistantView): url = "/api/tts_get_url" name = "api:tts:geturl" - def __init__(self, tts): + def __init__(self, tts: SpeechManager) -> None: """Initialize a tts view.""" self.tts = tts @@ -627,7 +663,7 @@ class TextToSpeechView(HomeAssistantView): url = "/api/tts_proxy/{filename}" name = "api:tts_speech" - def __init__(self, tts): + def __init__(self, tts: SpeechManager) -> None: """Initialize a tts view.""" self.tts = tts