mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Add type hints to tts (#64050)
* Add type hint for _get_cache_files * Add type hint for _init_tts_cache_dir * Add init type hints for async_clear_cache * Add type hints to async_setup_platform * Add type hints to async_register_engine * Add type hints to self.providers * Add type hints to _async_store_to_memcache * Add type hints to async_file_to_mem * Add full type hints * Use tuple in async_read_tts Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
parent
44a686931e
commit
c1692a324b
@ -10,7 +10,7 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
@ -38,6 +38,7 @@ from homeassistant.helpers import config_per_platform, discovery
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
@ -117,7 +118,7 @@ SCHEMA_SERVICE_SAY = vol.Schema(
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up TTS."""
|
||||
tts = SpeechManager(hass)
|
||||
|
||||
@ -144,7 +145,11 @@ async def async_setup(hass, config):
|
||||
dict, await hass.async_add_executor_job(load_yaml, str(services_yaml))
|
||||
)
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
async def async_setup_platform(
|
||||
p_type: str,
|
||||
p_config: ConfigType | None = None,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up a TTS platform."""
|
||||
if p_config is None:
|
||||
p_config = {}
|
||||
@ -175,7 +180,7 @@ async def async_setup(hass, config):
|
||||
async def async_say_handle(service: ServiceCall) -> None:
|
||||
"""Service handle for say."""
|
||||
entity_ids = service.data[ATTR_ENTITY_ID]
|
||||
message = service.data.get(ATTR_MESSAGE)
|
||||
message = service.data[ATTR_MESSAGE]
|
||||
cache = service.data.get(ATTR_CACHE)
|
||||
language = service.data.get(ATTR_LANGUAGE)
|
||||
options = service.data.get(ATTR_OPTIONS)
|
||||
@ -221,6 +226,7 @@ async def async_setup(hass, config):
|
||||
setup_tasks = [
|
||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||
if p_type is not None
|
||||
]
|
||||
|
||||
if setup_tasks:
|
||||
@ -259,19 +265,21 @@ def _hash_options(options: dict) -> str:
|
||||
class SpeechManager:
|
||||
"""Representation of a speech store."""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize a speech store."""
|
||||
self.hass = hass
|
||||
self.providers = {}
|
||||
self.providers: dict[str, Provider] = {}
|
||||
|
||||
self.use_cache = DEFAULT_CACHE
|
||||
self.cache_dir = DEFAULT_CACHE_DIR
|
||||
self.time_memory = DEFAULT_TIME_MEMORY
|
||||
self.base_url = None
|
||||
self.file_cache = {}
|
||||
self.mem_cache = {}
|
||||
self.base_url: str | None = None
|
||||
self.file_cache: dict[str, str] = {}
|
||||
self.mem_cache: dict[str, dict[str, str | bytes]] = {}
|
||||
|
||||
async def async_init_cache(self, use_cache, cache_dir, time_memory, base_url):
|
||||
async def async_init_cache(
|
||||
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
|
||||
) -> None:
|
||||
"""Init config folder and load file cache."""
|
||||
self.use_cache = use_cache
|
||||
self.time_memory = time_memory
|
||||
@ -294,7 +302,7 @@ class SpeechManager:
|
||||
if cache_files:
|
||||
self.file_cache.update(cache_files)
|
||||
|
||||
async def async_clear_cache(self):
|
||||
async def async_clear_cache(self) -> None:
|
||||
"""Read file cache and delete files."""
|
||||
self.mem_cache = {}
|
||||
|
||||
@ -310,7 +318,9 @@ class SpeechManager:
|
||||
self.file_cache = {}
|
||||
|
||||
@callback
|
||||
def async_register_engine(self, engine, provider, config):
|
||||
def async_register_engine(
|
||||
self, engine: str, provider: Provider, config: ConfigType
|
||||
) -> None:
|
||||
"""Register a TTS provider."""
|
||||
provider.hass = self.hass
|
||||
if provider.name is None:
|
||||
@ -322,8 +332,13 @@ class SpeechManager:
|
||||
)
|
||||
|
||||
async def async_get_url_path(
|
||||
self, engine, message, cache=None, language=None, options=None
|
||||
):
|
||||
self,
|
||||
engine: str,
|
||||
message: str,
|
||||
cache: bool | None = None,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""Get URL for play message.
|
||||
|
||||
This method is a coroutine.
|
||||
@ -362,7 +377,7 @@ class SpeechManager:
|
||||
|
||||
# Is speech already in memory
|
||||
if key in self.mem_cache:
|
||||
filename = self.mem_cache[key][MEM_CACHE_FILENAME]
|
||||
filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
|
||||
# Is file store in file cache
|
||||
elif use_cache and key in self.file_cache:
|
||||
filename = self.file_cache[key]
|
||||
@ -375,7 +390,15 @@ class SpeechManager:
|
||||
|
||||
return f"/api/tts_proxy/{filename}"
|
||||
|
||||
async def async_get_tts_audio(self, engine, key, message, cache, language, options):
|
||||
async def async_get_tts_audio(
|
||||
self,
|
||||
engine: str,
|
||||
key: str,
|
||||
message: str,
|
||||
cache: bool,
|
||||
language: str,
|
||||
options: dict | None,
|
||||
) -> str:
|
||||
"""Receive TTS and store for view in cache.
|
||||
|
||||
This method is a coroutine.
|
||||
@ -404,14 +427,14 @@ class SpeechManager:
|
||||
|
||||
return filename
|
||||
|
||||
async def async_save_tts_audio(self, key, filename, data):
|
||||
async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
|
||||
"""Store voice data to file and file_cache.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def save_speech():
|
||||
def save_speech() -> None:
|
||||
"""Store speech to filesystem."""
|
||||
with open(voice_file, "wb") as speech:
|
||||
speech.write(data)
|
||||
@ -422,7 +445,7 @@ class SpeechManager:
|
||||
except OSError as err:
|
||||
_LOGGER.error("Can't write %s: %s", filename, err)
|
||||
|
||||
async def async_file_to_mem(self, key):
|
||||
async def async_file_to_mem(self, key: str) -> None:
|
||||
"""Load voice from file cache into memory.
|
||||
|
||||
This method is a coroutine.
|
||||
@ -432,7 +455,7 @@ class SpeechManager:
|
||||
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def load_speech():
|
||||
def load_speech() -> bytes:
|
||||
"""Load a speech from filesystem."""
|
||||
with open(voice_file, "rb") as speech:
|
||||
return speech.read()
|
||||
@ -446,18 +469,18 @@ class SpeechManager:
|
||||
self._async_store_to_memcache(key, filename, data)
|
||||
|
||||
@callback
|
||||
def _async_store_to_memcache(self, key, filename, data):
|
||||
def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
|
||||
"""Store data to memcache and set timer to remove it."""
|
||||
self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}
|
||||
|
||||
@callback
|
||||
def async_remove_from_mem():
|
||||
def async_remove_from_mem() -> None:
|
||||
"""Cleanup memcache."""
|
||||
self.mem_cache.pop(key, None)
|
||||
|
||||
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
|
||||
|
||||
async def async_read_tts(self, filename):
|
||||
async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
|
||||
"""Read a voice file and return binary.
|
||||
|
||||
This method is a coroutine.
|
||||
@ -475,10 +498,17 @@ class SpeechManager:
|
||||
await self.async_file_to_mem(key)
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
return content, self.mem_cache[key][MEM_CACHE_VOICE]
|
||||
return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])
|
||||
|
||||
@staticmethod
|
||||
def write_tags(filename, data, provider, message, language, options):
|
||||
def write_tags(
|
||||
filename: str,
|
||||
data: bytes,
|
||||
provider: Provider,
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict | None,
|
||||
) -> bytes:
|
||||
"""Write ID3 tags to file.
|
||||
|
||||
Async friendly.
|
||||
@ -491,8 +521,8 @@ class SpeechManager:
|
||||
album = provider.name
|
||||
artist = language
|
||||
|
||||
if options is not None and options.get("voice") is not None:
|
||||
artist = options.get("voice")
|
||||
if options is not None and (voice := options.get("voice")) is not None:
|
||||
artist = voice
|
||||
|
||||
try:
|
||||
tts_file = mutagen.File(data_bytes)
|
||||
@ -540,21 +570,27 @@ class Provider:
|
||||
"""Return a dict include default options."""
|
||||
return None
|
||||
|
||||
def get_tts_audio(self, message, language, options=None):
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(self, message, language, options=None):
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self.hass
|
||||
return await self.hass.async_add_executor_job(
|
||||
ft.partial(self.get_tts_audio, message, language, options=options)
|
||||
)
|
||||
|
||||
|
||||
def _init_tts_cache_dir(hass, cache_dir):
|
||||
def _init_tts_cache_dir(hass: HomeAssistant, cache_dir: str) -> str:
|
||||
"""Init cache folder."""
|
||||
if not os.path.isabs(cache_dir):
|
||||
cache_dir = hass.config.path(cache_dir)
|
||||
@ -564,7 +600,7 @@ def _init_tts_cache_dir(hass, cache_dir):
|
||||
return cache_dir
|
||||
|
||||
|
||||
def _get_cache_files(cache_dir):
|
||||
def _get_cache_files(cache_dir: str) -> dict[str, str]:
|
||||
"""Return a dict of given engine files."""
|
||||
cache = {}
|
||||
|
||||
@ -585,7 +621,7 @@ class TextToSpeechUrlView(HomeAssistantView):
|
||||
url = "/api/tts_get_url"
|
||||
name = "api:tts:geturl"
|
||||
|
||||
def __init__(self, tts):
|
||||
def __init__(self, tts: SpeechManager) -> None:
|
||||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
@ -627,7 +663,7 @@ class TextToSpeechView(HomeAssistantView):
|
||||
url = "/api/tts_proxy/{filename}"
|
||||
name = "api:tts_speech"
|
||||
|
||||
def __init__(self, tts):
|
||||
def __init__(self, tts: SpeechManager) -> None:
|
||||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user