mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add type hints to TTS provider (#78285)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
9d47160e68
commit
55e59b778c
@ -11,7 +11,7 @@ import mimetypes
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import mutagen
|
import mutagen
|
||||||
@ -49,8 +49,6 @@ from homeassistant.util.yaml import load_yaml
|
|||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
TtsAudioType = tuple[Optional[str], Optional[bytes]]
|
TtsAudioType = tuple[Optional[str], Optional[bytes]]
|
||||||
@ -86,7 +84,7 @@ _RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]
|
|||||||
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_platform(value):
|
def _deprecated_platform(value: str) -> str:
|
||||||
"""Validate if platform is deprecated."""
|
"""Validate if platform is deprecated."""
|
||||||
if value == "google":
|
if value == "google":
|
||||||
raise vol.Invalid(
|
raise vol.Invalid(
|
||||||
@ -253,7 +251,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
if setup_tasks:
|
if setup_tasks:
|
||||||
await asyncio.wait(setup_tasks)
|
await asyncio.wait(setup_tasks)
|
||||||
|
|
||||||
async def async_platform_discovered(platform, info):
|
async def async_platform_discovered(
|
||||||
|
platform: str, info: dict[str, Any] | None
|
||||||
|
) -> None:
|
||||||
"""Handle for discovered platform."""
|
"""Handle for discovered platform."""
|
||||||
await async_setup_platform(platform, discovery_info=info)
|
await async_setup_platform(platform, discovery_info=info)
|
||||||
|
|
||||||
@ -327,7 +327,7 @@ class SpeechManager:
|
|||||||
"""Read file cache and delete files."""
|
"""Read file cache and delete files."""
|
||||||
self.mem_cache = {}
|
self.mem_cache = {}
|
||||||
|
|
||||||
def remove_files():
|
def remove_files() -> None:
|
||||||
"""Remove files from filesystem."""
|
"""Remove files from filesystem."""
|
||||||
for filename in self.file_cache.values():
|
for filename in self.file_cache.values():
|
||||||
try:
|
try:
|
||||||
@ -365,7 +365,11 @@ class SpeechManager:
|
|||||||
|
|
||||||
# Languages
|
# Languages
|
||||||
language = language or provider.default_language
|
language = language or provider.default_language
|
||||||
if language is None or language not in provider.supported_languages:
|
if (
|
||||||
|
language is None
|
||||||
|
or provider.supported_languages is None
|
||||||
|
or language not in provider.supported_languages
|
||||||
|
):
|
||||||
raise HomeAssistantError(f"Not supported language {language}")
|
raise HomeAssistantError(f"Not supported language {language}")
|
||||||
|
|
||||||
# Options
|
# Options
|
||||||
@ -583,33 +587,33 @@ class Provider:
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_language(self):
|
def default_language(self) -> str | None:
|
||||||
"""Return the default language."""
|
"""Return the default language."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self):
|
def supported_languages(self) -> list[str] | None:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_options(self):
|
def supported_options(self) -> list[str] | None:
|
||||||
"""Return a list of supported options like voice, emotionen."""
|
"""Return a list of supported options like voice, emotions."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_options(self):
|
def default_options(self) -> dict[str, Any] | None:
|
||||||
"""Return a dict include default options."""
|
"""Return a dict include default options."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tts_audio(
|
def get_tts_audio(
|
||||||
self, message: str, language: str, options: dict | None = None
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load tts audio file from provider."""
|
"""Load tts audio file from provider."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
self, message: str, language: str, options: dict | None = None
|
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load tts audio file from provider.
|
"""Load tts audio file from provider.
|
||||||
|
|
||||||
|
@ -1,20 +1,22 @@
|
|||||||
"""Support notifications through TTS service."""
|
"""Support notifications through TTS service."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
|
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
|
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
|
||||||
from homeassistant.core import split_entity_id
|
from homeassistant.core import HomeAssistant, split_entity_id
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
|
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
|
||||||
|
|
||||||
CONF_MEDIA_PLAYER = "media_player"
|
CONF_MEDIA_PLAYER = "media_player"
|
||||||
CONF_TTS_SERVICE = "tts_service"
|
CONF_TTS_SERVICE = "tts_service"
|
||||||
|
|
||||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||||
@ -27,7 +29,11 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_get_service(hass, config, discovery_info=None):
|
async def async_get_service(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> TTSNotificationService:
|
||||||
"""Return the notify service."""
|
"""Return the notify service."""
|
||||||
|
|
||||||
return TTSNotificationService(config)
|
return TTSNotificationService(config)
|
||||||
@ -36,13 +42,13 @@ async def async_get_service(hass, config, discovery_info=None):
|
|||||||
class TTSNotificationService(BaseNotificationService):
|
class TTSNotificationService(BaseNotificationService):
|
||||||
"""The TTS Notification Service."""
|
"""The TTS Notification Service."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: ConfigType) -> None:
|
||||||
"""Initialize the service."""
|
"""Initialize the service."""
|
||||||
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
|
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
|
||||||
self._media_player = config[CONF_MEDIA_PLAYER]
|
self._media_player = config[CONF_MEDIA_PLAYER]
|
||||||
self._language = config.get(ATTR_LANGUAGE)
|
self._language = config.get(ATTR_LANGUAGE)
|
||||||
|
|
||||||
async def async_send_message(self, message="", **kwargs):
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
"""Call TTS service to speak the notification."""
|
"""Call TTS service to speak the notification."""
|
||||||
_LOGGER.debug("%s '%s' on %s", self._tts_service, message, self._media_player)
|
_LOGGER.debug("%s '%s' on %s", self._tts_service, message, self._media_player)
|
||||||
|
|
||||||
|
@ -2127,6 +2127,35 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
"tts": [
|
||||||
|
ClassTypeHintMatch(
|
||||||
|
base_class="Provider",
|
||||||
|
matches=[
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="default_language",
|
||||||
|
return_type=["str", None],
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="supported_languages",
|
||||||
|
return_type=["list[str]", None],
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="supported_options",
|
||||||
|
return_type=["list[str]", None],
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="default_options",
|
||||||
|
return_type=["dict[str, Any]", None],
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="get_tts_audio",
|
||||||
|
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"},
|
||||||
|
return_type="TtsAudioType",
|
||||||
|
has_async_counterpart=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
"update": [
|
"update": [
|
||||||
ClassTypeHintMatch(
|
ClassTypeHintMatch(
|
||||||
base_class="Entity",
|
base_class="Entity",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user