Move TTS entity to own file (#139538)

* Move entity to own file

* Move entity tests
This commit is contained in:
Paulus Schoutsen 2025-02-28 19:40:13 +00:00 committed by GitHub
parent 455363871f
commit 1a80934593
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 310 additions and 292 deletions

View File

@ -3,10 +3,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from functools import partial
import hashlib
from http import HTTPStatus
import io
@ -18,7 +16,7 @@ import secrets
import subprocess
import tempfile
from time import monotonic
from typing import Any, Final, TypedDict, final
from typing import Any, Final, TypedDict
from aiohttp import web
import mutagen
@ -28,22 +26,8 @@ import voluptuous as vol
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_ENTITY_ID,
EVENT_HOMEASSISTANT_STOP,
PLATFORM_FORMAT,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT
from homeassistant.core import (
CALLBACK_TYPE,
Event,
@ -58,9 +42,8 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import UNDEFINED, ConfigType
from homeassistant.util import dt as dt_util, language as language_util
from homeassistant.util import language as language_util
from .const import (
ATTR_CACHE,
@ -78,6 +61,7 @@ from .const import (
DOMAIN,
TtsAudioType,
)
from .entity import TextToSpeechEntity
from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs
@ -95,6 +79,7 @@ __all__ = [
"PLATFORM_SCHEMA_BASE",
"Provider",
"SampleFormat",
"TextToSpeechEntity",
"TtsAudioType",
"Voice",
"async_default_engine",
@ -389,14 +374,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
CACHED_PROPERTIES_WITH_ATTR_ = {
"default_language",
"default_options",
"supported_languages",
"supported_options",
}
@dataclass
class ResultStream:
"""Class that will stream the result when available."""
@ -431,134 +408,6 @@ class ResultStream:
return data
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
"""Represent a single TTS engine."""
_attr_should_poll = False
__last_tts_loaded: str | None = None
_attr_default_language: str
_attr_default_options: Mapping[str, Any] | None = None
_attr_supported_languages: list[str]
_attr_supported_options: list[str] | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_tts_loaded is None:
return None
return self.__last_tts_loaded
@cached_property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._attr_supported_languages
@cached_property
def default_language(self) -> str:
"""Return the default language."""
return self._attr_default_language
@cached_property
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return self._attr_supported_options
@cached_property
def default_options(self) -> Mapping[str, Any] | None:
"""Return a mapping with the default options."""
return self._attr_default_options
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return a list of supported voices for a language."""
return None
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
try:
_ = self.default_language
except AttributeError as err:
raise AttributeError(
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
) from err
try:
_ = self.supported_languages
except AttributeError as err:
raise AttributeError(
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
) from err
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_tts_loaded = state.state
async def async_speak(
self,
media_player_entity_id: list[str],
message: str,
cache: bool,
language: str | None = None,
options: dict | None = None,
) -> None:
"""Speak via a Media Player."""
await self.hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: media_player_entity_id,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
self.hass,
message=message,
engine=self.entity_id,
language=language,
options=options,
cache=cache,
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
context=self._context,
)
@final
async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Process an audio stream to TTS service.
Only streaming content is allowed!
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(
message=message, language=language, options=options
)
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine."""
raise NotImplementedError
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine.
Return a tuple of file extension and data as bytes.
"""
return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options)
)
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
opts_hash = hashlib.blake2s(digest_size=5)

View File

@ -0,0 +1,159 @@
"""Entity for Text-to-Speech."""
from collections.abc import Mapping
from functools import partial
from typing import Any, final
from propcache.api import cached_property
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import callback
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .const import TtsAudioType
from .media_source import generate_media_source_id
from .models import Voice
CACHED_PROPERTIES_WITH_ATTR_ = {
"default_language",
"default_options",
"supported_languages",
"supported_options",
}
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
"""Represent a single TTS engine."""
_attr_should_poll = False
__last_tts_loaded: str | None = None
_attr_default_language: str
_attr_default_options: Mapping[str, Any] | None = None
_attr_supported_languages: list[str]
_attr_supported_options: list[str] | None = None
@property
@final
def state(self) -> str | None:
"""Return the state of the entity."""
if self.__last_tts_loaded is None:
return None
return self.__last_tts_loaded
@cached_property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._attr_supported_languages
@cached_property
def default_language(self) -> str:
"""Return the default language."""
return self._attr_default_language
@cached_property
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return self._attr_supported_options
@cached_property
def default_options(self) -> Mapping[str, Any] | None:
"""Return a mapping with the default options."""
return self._attr_default_options
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return a list of supported voices for a language."""
return None
async def async_internal_added_to_hass(self) -> None:
"""Call when the entity is added to hass."""
await super().async_internal_added_to_hass()
try:
_ = self.default_language
except AttributeError as err:
raise AttributeError(
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
) from err
try:
_ = self.supported_languages
except AttributeError as err:
raise AttributeError(
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
) from err
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_tts_loaded = state.state
async def async_speak(
self,
media_player_entity_id: list[str],
message: str,
cache: bool,
language: str | None = None,
options: dict | None = None,
) -> None:
"""Speak via a Media Player."""
await self.hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: media_player_entity_id,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
self.hass,
message=message,
engine=self.entity_id,
language=language,
options=options,
cache=cache,
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
context=self._context,
)
@final
async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Process an audio stream to TTS service.
Only streaming content is allowed!
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(
message=message, language=language, options=options
)
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine."""
raise NotImplementedError
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine.
Return a tuple of file extension and data as bytes.
"""
return await self.hass.async_add_executor_job(
partial(self.get_tts_audio, message, language, options=options)
)

View File

@ -0,0 +1,144 @@
"""Tests for the TTS entity."""
import pytest
from homeassistant.components import tts
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant, State
from .common import (
DEFAULT_LANG,
SUPPORT_LANGUAGES,
TEST_DOMAIN,
MockTTSEntity,
mock_config_entry_setup,
)
from tests.common import mock_restore_cache
class DefaultEntity(tts.TextToSpeechEntity):
"""Test entity."""
_attr_supported_languages = SUPPORT_LANGUAGES
_attr_default_language = DEFAULT_LANG
async def test_default_entity_attributes() -> None:
"""Test default entity attributes."""
entity = DefaultEntity()
assert entity.hass is None
assert entity.default_language == DEFAULT_LANG
assert entity.supported_languages == SUPPORT_LANGUAGES
assert entity.supported_options is None
assert entity.default_options is None
assert entity.async_get_supported_voices("test") is None
async def test_restore_state(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
await hass.async_block_till_done()
assert config_entry.state is ConfigEntryState.LOADED
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp
async def test_tts_entity_subclass_properties(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test for errors when subclasses of the TextToSpeechEntity are missing required properties."""
class TestClass1(tts.TextToSpeechEntity):
_attr_default_language = DEFAULT_LANG
_attr_supported_languages = SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass1())
class TestClass2(tts.TextToSpeechEntity):
@property
def default_language(self) -> str:
return DEFAULT_LANG
@property
def supported_languages(self) -> list[str]:
return SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass2())
assert all(record.exc_info is None for record in caplog.records)
caplog.clear()
class TestClass3(tts.TextToSpeechEntity):
_attr_default_language = DEFAULT_LANG
await mock_config_entry_setup(hass, TestClass3())
assert (
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass4(tts.TextToSpeechEntity):
_attr_supported_languages = SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass4())
assert (
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass5(tts.TextToSpeechEntity):
@property
def default_language(self) -> str:
return DEFAULT_LANG
await mock_config_entry_setup(hass, TestClass5())
assert (
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass6(tts.TextToSpeechEntity):
@property
def supported_languages(self) -> list[str]:
return SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass6())
assert (
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)

View File

@ -20,14 +20,13 @@ from homeassistant.components.media_player import (
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from .common import (
DEFAULT_LANG,
SUPPORT_LANGUAGES,
TEST_DOMAIN,
MockTTS,
MockTTSEntity,
@ -38,37 +37,12 @@ from .common import (
retrieve_media,
)
from tests.common import (
MockModule,
async_mock_service,
mock_integration,
mock_platform,
mock_restore_cache,
)
from tests.common import MockModule, async_mock_service, mock_integration, mock_platform
from tests.typing import ClientSessionGenerator, WebSocketGenerator
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
class DefaultEntity(tts.TextToSpeechEntity):
"""Test entity."""
_attr_supported_languages = SUPPORT_LANGUAGES
_attr_default_language = DEFAULT_LANG
async def test_default_entity_attributes() -> None:
"""Test default entity attributes."""
entity = DefaultEntity()
assert entity.hass is None
assert entity.default_language == DEFAULT_LANG
assert entity.supported_languages == SUPPORT_LANGUAGES
assert entity.supported_options is None
assert entity.default_options is None
assert entity.async_get_supported_voices("test") is None
async def test_config_entry_unload(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
@ -120,24 +94,6 @@ async def test_config_entry_unload(
assert state is None
async def test_restore_state(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
await hass.async_block_till_done()
assert config_entry.state is ConfigEntryState.LOADED
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp
@pytest.mark.parametrize(
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
)
@ -1840,96 +1796,6 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None:
await tts.async_convert_audio(hass, "wav", bytes(0), "mp3")
async def test_ttsentity_subclass_properties(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test for errors when subclasses of the TextToSpeechEntity are missing required properties."""
class TestClass1(tts.TextToSpeechEntity):
_attr_default_language = DEFAULT_LANG
_attr_supported_languages = SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass1())
class TestClass2(tts.TextToSpeechEntity):
@property
def default_language(self) -> str:
return DEFAULT_LANG
@property
def supported_languages(self) -> list[str]:
return SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass2())
assert all(record.exc_info is None for record in caplog.records)
caplog.clear()
class TestClass3(tts.TextToSpeechEntity):
_attr_default_language = DEFAULT_LANG
await mock_config_entry_setup(hass, TestClass3())
assert (
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass4(tts.TextToSpeechEntity):
_attr_supported_languages = SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass4())
assert (
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass5(tts.TextToSpeechEntity):
@property
def default_language(self) -> str:
return DEFAULT_LANG
await mock_config_entry_setup(hass, TestClass5())
assert (
"TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
caplog.clear()
class TestClass6(tts.TextToSpeechEntity):
@property
def supported_languages(self) -> list[str]:
return SUPPORT_LANGUAGES
await mock_config_entry_setup(hass, TestClass6())
assert (
"TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property"
in [
str(record.exc_info[1])
for record in caplog.records
if record.exc_info is not None
]
)
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,