diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 199d644738b..5b2da44eae2 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -3,10 +3,8 @@ from __future__ import annotations import asyncio -from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from functools import partial import hashlib from http import HTTPStatus import io @@ -18,7 +16,7 @@ import secrets import subprocess import tempfile from time import monotonic -from typing import Any, Final, TypedDict, final +from typing import Any, Final, TypedDict from aiohttp import web import mutagen @@ -28,22 +26,8 @@ import voluptuous as vol from homeassistant.components import ffmpeg, websocket_api from homeassistant.components.http import HomeAssistantView -from homeassistant.components.media_player import ( - ATTR_MEDIA_ANNOUNCE, - ATTR_MEDIA_CONTENT_ID, - ATTR_MEDIA_CONTENT_TYPE, - DOMAIN as DOMAIN_MP, - SERVICE_PLAY_MEDIA, - MediaType, -) from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ( - ATTR_ENTITY_ID, - EVENT_HOMEASSISTANT_STOP, - PLATFORM_FORMAT, - STATE_UNAVAILABLE, - STATE_UNKNOWN, -) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT from homeassistant.core import ( CALLBACK_TYPE, Event, @@ -58,9 +42,8 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_call_later from homeassistant.helpers.network import get_url -from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import UNDEFINED, ConfigType -from homeassistant.util import dt as dt_util, language as language_util +from homeassistant.util import language as language_util from .const import ( ATTR_CACHE, @@ -78,6 +61,7 @@ from .const import ( DOMAIN, TtsAudioType, ) +from .entity import TextToSpeechEntity from .helper import get_engine_instance from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy from .media_source import generate_media_source_id, media_source_id_to_kwargs @@ -95,6 +79,7 @@ __all__ = [ "PLATFORM_SCHEMA_BASE", "Provider", "SampleFormat", + "TextToSpeechEntity", "TtsAudioType", "Voice", "async_default_engine", @@ -389,14 +374,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await hass.data[DATA_COMPONENT].async_unload_entry(entry) -CACHED_PROPERTIES_WITH_ATTR_ = { - "default_language", - "default_options", - "supported_languages", - "supported_options", -} - - @dataclass class ResultStream: """Class that will stream the result when available.""" @@ -431,134 +408,6 @@ class ResultStream: return data -class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_): - """Represent a single TTS engine.""" - - _attr_should_poll = False - __last_tts_loaded: str | None = None - - _attr_default_language: str - _attr_default_options: Mapping[str, Any] | None = None - _attr_supported_languages: list[str] - _attr_supported_options: list[str] | None = None - - @property - @final - def state(self) -> str | None: - """Return the state of the entity.""" - if self.__last_tts_loaded is None: - return None - return self.__last_tts_loaded - - @cached_property - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - return self._attr_supported_languages - - @cached_property - def default_language(self) -> str: - """Return the default language.""" - return self._attr_default_language - - @cached_property - def supported_options(self) -> list[str] | None: - """Return a list of supported options like voice, emotions.""" - return self._attr_supported_options - - @cached_property - def default_options(self) -> Mapping[str, Any] | None: - """Return a mapping with the default options.""" - return self._attr_default_options - - @callback - def async_get_supported_voices(self, language: str) -> list[Voice] | None: - """Return a list of supported voices for a language.""" - return None - - async def async_internal_added_to_hass(self) -> None: - """Call when the entity is added to hass.""" - await super().async_internal_added_to_hass() - try: - _ = self.default_language - except AttributeError as err: - raise AttributeError( - "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" - ) from err - try: - _ = self.supported_languages - except AttributeError as err: - raise AttributeError( - "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" - ) from err - state = await self.async_get_last_state() - if ( - state is not None - and state.state is not None - and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) - ): - self.__last_tts_loaded = state.state - - async def async_speak( - self, - media_player_entity_id: list[str], - message: str, - cache: bool, - language: str | None = None, - options: dict | None = None, - ) -> None: - """Speak via a Media Player.""" - await self.hass.services.async_call( - DOMAIN_MP, - SERVICE_PLAY_MEDIA, - { - ATTR_ENTITY_ID: media_player_entity_id, - ATTR_MEDIA_CONTENT_ID: generate_media_source_id( - self.hass, - message=message, - engine=self.entity_id, - language=language, - options=options, - cache=cache, - ), - ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, - ATTR_MEDIA_ANNOUNCE: True, - }, - blocking=True, - context=self._context, - ) - - @final - async def internal_async_get_tts_audio( - self, message: str, language: str, options: dict[str, Any] - ) -> TtsAudioType: - """Process an audio stream to TTS service. - - Only streaming content is allowed! - """ - self.__last_tts_loaded = dt_util.utcnow().isoformat() - self.async_write_ha_state() - return await self.async_get_tts_audio( - message=message, language=language, options=options - ) - - def get_tts_audio( - self, message: str, language: str, options: dict[str, Any] - ) -> TtsAudioType: - """Load tts audio file from the engine.""" - raise NotImplementedError - - async def async_get_tts_audio( - self, message: str, language: str, options: dict[str, Any] - ) -> TtsAudioType: - """Load tts audio file from the engine. - - Return a tuple of file extension and data as bytes. - """ - return await self.hass.async_add_executor_job( - partial(self.get_tts_audio, message, language, options=options) - ) - - def _hash_options(options: dict) -> str: """Hashes an options dictionary.""" opts_hash = hashlib.blake2s(digest_size=5) diff --git a/homeassistant/components/tts/entity.py b/homeassistant/components/tts/entity.py new file mode 100644 index 00000000000..ef65886452d --- /dev/null +++ b/homeassistant/components/tts/entity.py @@ -0,0 +1,159 @@ +"""Entity for Text-to-Speech.""" + +from collections.abc import Mapping +from functools import partial +from typing import Any, final + +from propcache.api import cached_property + +from homeassistant.components.media_player import ( + ATTR_MEDIA_ANNOUNCE, + ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, + DOMAIN as DOMAIN_MP, + SERVICE_PLAY_MEDIA, + MediaType, +) +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.core import callback +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.util import dt as dt_util + +from .const import TtsAudioType +from .media_source import generate_media_source_id +from .models import Voice + +CACHED_PROPERTIES_WITH_ATTR_ = { + "default_language", + "default_options", + "supported_languages", + "supported_options", +} + + +class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_): + """Represent a single TTS engine.""" + + _attr_should_poll = False + __last_tts_loaded: str | None = None + + _attr_default_language: str + _attr_default_options: Mapping[str, Any] | None = None + _attr_supported_languages: list[str] + _attr_supported_options: list[str] | None = None + + @property + @final + def state(self) -> str | None: + """Return the state of the entity.""" + if self.__last_tts_loaded is None: + return None + return self.__last_tts_loaded + + @cached_property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return self._attr_supported_languages + + @cached_property + def default_language(self) -> str: + """Return the default language.""" + return self._attr_default_language + + @cached_property + def supported_options(self) -> list[str] | None: + """Return a list of supported options like voice, emotions.""" + return self._attr_supported_options + + @cached_property + def default_options(self) -> Mapping[str, Any] | None: + """Return a mapping with the default options.""" + return self._attr_default_options + + @callback + def async_get_supported_voices(self, language: str) -> list[Voice] | None: + """Return a list of supported voices for a language.""" + return None + + async def async_internal_added_to_hass(self) -> None: + """Call when the entity is added to hass.""" + await super().async_internal_added_to_hass() + try: + _ = self.default_language + except AttributeError as err: + raise AttributeError( + "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" + ) from err + try: + _ = self.supported_languages + except AttributeError as err: + raise AttributeError( + "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" + ) from err + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_tts_loaded = state.state + + async def async_speak( + self, + media_player_entity_id: list[str], + message: str, + cache: bool, + language: str | None = None, + options: dict | None = None, + ) -> None: + """Speak via a Media Player.""" + await self.hass.services.async_call( + DOMAIN_MP, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: media_player_entity_id, + ATTR_MEDIA_CONTENT_ID: generate_media_source_id( + self.hass, + message=message, + engine=self.entity_id, + language=language, + options=options, + cache=cache, + ), + ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, + ATTR_MEDIA_ANNOUNCE: True, + }, + blocking=True, + context=self._context, + ) + + @final + async def internal_async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Process an audio stream to TTS service. + + Only streaming content is allowed! + """ + self.__last_tts_loaded = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self.async_get_tts_audio( + message=message, language=language, options=options + ) + + def get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load tts audio file from the engine.""" + raise NotImplementedError + + async def async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load tts audio file from the engine. + + Return a tuple of file extension and data as bytes. + """ + return await self.hass.async_add_executor_job( + partial(self.get_tts_audio, message, language, options=options) + ) diff --git a/tests/components/tts/test_entity.py b/tests/components/tts/test_entity.py new file mode 100644 index 00000000000..d82ec6a5d2b --- /dev/null +++ b/tests/components/tts/test_entity.py @@ -0,0 +1,144 @@ +"""Tests for the TTS entity.""" + +import pytest + +from homeassistant.components import tts +from homeassistant.config_entries import ConfigEntryState +from homeassistant.core import HomeAssistant, State + +from .common import ( + DEFAULT_LANG, + SUPPORT_LANGUAGES, + TEST_DOMAIN, + MockTTSEntity, + mock_config_entry_setup, +) + +from tests.common import mock_restore_cache + + +class DefaultEntity(tts.TextToSpeechEntity): + """Test entity.""" + + _attr_supported_languages = SUPPORT_LANGUAGES + _attr_default_language = DEFAULT_LANG + + +async def test_default_entity_attributes() -> None: + """Test default entity attributes.""" + entity = DefaultEntity() + + assert entity.hass is None + assert entity.default_language == DEFAULT_LANG + assert entity.supported_languages == SUPPORT_LANGUAGES + assert entity.supported_options is None + assert entity.default_options is None + assert entity.async_get_supported_voices("test") is None + + +async def test_restore_state( + hass: HomeAssistant, + mock_tts_entity: MockTTSEntity, +) -> None: + """Test we restore state in the integration.""" + entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}" + timestamp = "2023-01-01T23:59:59+00:00" + mock_restore_cache(hass, (State(entity_id, timestamp),)) + + config_entry = await mock_config_entry_setup(hass, mock_tts_entity) + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.LOADED + state = hass.states.get(entity_id) + assert state + assert state.state == timestamp + + +async def test_tts_entity_subclass_properties( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test for errors when subclasses of the TextToSpeechEntity are missing required properties.""" + + class TestClass1(tts.TextToSpeechEntity): + _attr_default_language = DEFAULT_LANG + _attr_supported_languages = SUPPORT_LANGUAGES + + await mock_config_entry_setup(hass, TestClass1()) + + class TestClass2(tts.TextToSpeechEntity): + @property + def default_language(self) -> str: + return DEFAULT_LANG + + @property + def supported_languages(self) -> list[str]: + return SUPPORT_LANGUAGES + + await mock_config_entry_setup(hass, TestClass2()) + + assert all(record.exc_info is None for record in caplog.records) + + caplog.clear() + + class TestClass3(tts.TextToSpeechEntity): + _attr_default_language = DEFAULT_LANG + + await mock_config_entry_setup(hass, TestClass3()) + + assert ( + "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" + in [ + str(record.exc_info[1]) + for record in caplog.records + if record.exc_info is not None + ] + ) + caplog.clear() + + class TestClass4(tts.TextToSpeechEntity): + _attr_supported_languages = SUPPORT_LANGUAGES + + await mock_config_entry_setup(hass, TestClass4()) + + assert ( + "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" + in [ + str(record.exc_info[1]) + for record in caplog.records + if record.exc_info is not None + ] + ) + caplog.clear() + + class TestClass5(tts.TextToSpeechEntity): + @property + def default_language(self) -> str: + return DEFAULT_LANG + + await mock_config_entry_setup(hass, TestClass5()) + + assert ( + "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" + in [ + str(record.exc_info[1]) + for record in caplog.records + if record.exc_info is not None + ] + ) + caplog.clear() + + class TestClass6(tts.TextToSpeechEntity): + @property + def supported_languages(self) -> list[str]: + return SUPPORT_LANGUAGES + + await mock_config_entry_setup(hass, TestClass6()) + + assert ( + "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" + in [ + str(record.exc_info[1]) + for record in caplog.records + if record.exc_info is not None + ] + ) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 86ca2de5791..8dece920907 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -20,14 +20,13 @@ from homeassistant.components.media_player import ( ) from homeassistant.config_entries import ConfigEntryState from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN -from homeassistant.core import HomeAssistant, State +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util from .common import ( DEFAULT_LANG, - SUPPORT_LANGUAGES, TEST_DOMAIN, MockTTS, MockTTSEntity, @@ -38,37 +37,12 @@ from .common import ( retrieve_media, ) -from tests.common import ( - MockModule, - async_mock_service, - mock_integration, - mock_platform, - mock_restore_cache, -) +from tests.common import MockModule, async_mock_service, mock_integration, mock_platform from tests.typing import ClientSessionGenerator, WebSocketGenerator ORIG_WRITE_TAGS = tts.SpeechManager.write_tags -class DefaultEntity(tts.TextToSpeechEntity): - """Test entity.""" - - _attr_supported_languages = SUPPORT_LANGUAGES - _attr_default_language = DEFAULT_LANG - - -async def test_default_entity_attributes() -> None: - """Test default entity attributes.""" - entity = DefaultEntity() - - assert entity.hass is None - assert entity.default_language == DEFAULT_LANG - assert entity.supported_languages == SUPPORT_LANGUAGES - assert entity.supported_options is None - assert entity.default_options is None - assert entity.async_get_supported_voices("test") is None - - async def test_config_entry_unload( hass: HomeAssistant, hass_client: ClientSessionGenerator, @@ -120,24 +94,6 @@ async def test_config_entry_unload( assert state is None -async def test_restore_state( - hass: HomeAssistant, - mock_tts_entity: MockTTSEntity, -) -> None: - """Test we restore state in the integration.""" - entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}" - timestamp = "2023-01-01T23:59:59+00:00" - mock_restore_cache(hass, (State(entity_id, timestamp),)) - - config_entry = await mock_config_entry_setup(hass, mock_tts_entity) - await hass.async_block_till_done() - - assert config_entry.state is ConfigEntryState.LOADED - state = hass.states.get(entity_id) - assert state - assert state.state == timestamp - - @pytest.mark.parametrize( "setup", ["mock_setup", "mock_config_entry_setup"], indirect=True ) @@ -1840,96 +1796,6 @@ async def test_async_convert_audio_error(hass: HomeAssistant) -> None: await tts.async_convert_audio(hass, "wav", bytes(0), "mp3") -async def test_ttsentity_subclass_properties( - hass: HomeAssistant, caplog: pytest.LogCaptureFixture -) -> None: - """Test for errors when subclasses of the TextToSpeechEntity are missing required properties.""" - - class TestClass1(tts.TextToSpeechEntity): - _attr_default_language = DEFAULT_LANG - _attr_supported_languages = SUPPORT_LANGUAGES - - await mock_config_entry_setup(hass, TestClass1()) - - class TestClass2(tts.TextToSpeechEntity): - @property - def default_language(self) -> str: - return DEFAULT_LANG - - @property - def supported_languages(self) -> list[str]: - return SUPPORT_LANGUAGES - - await mock_config_entry_setup(hass, TestClass2()) - - assert all(record.exc_info is None for record in caplog.records) - - caplog.clear() - - class TestClass3(tts.TextToSpeechEntity): - _attr_default_language = DEFAULT_LANG - - await mock_config_entry_setup(hass, TestClass3()) - - assert ( - "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" - in [ - str(record.exc_info[1]) - for record in caplog.records - if record.exc_info is not None - ] - ) - caplog.clear() - - class TestClass4(tts.TextToSpeechEntity): - _attr_supported_languages = SUPPORT_LANGUAGES - - await mock_config_entry_setup(hass, TestClass4()) - - assert ( - "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" - in [ - str(record.exc_info[1]) - for record in caplog.records - if record.exc_info is not None - ] - ) - caplog.clear() - - class TestClass5(tts.TextToSpeechEntity): - @property - def default_language(self) -> str: - return DEFAULT_LANG - - await mock_config_entry_setup(hass, TestClass5()) - - assert ( - "TTS entities must either set the '_attr_supported_languages' attribute or override the 'supported_languages' property" - in [ - str(record.exc_info[1]) - for record in caplog.records - if record.exc_info is not None - ] - ) - caplog.clear() - - class TestClass6(tts.TextToSpeechEntity): - @property - def supported_languages(self) -> list[str]: - return SUPPORT_LANGUAGES - - await mock_config_entry_setup(hass, TestClass6()) - - assert ( - "TTS entities must either set the '_attr_default_language' attribute or override the 'default_language' property" - in [ - str(record.exc_info[1]) - for record in caplog.records - if record.exc_info is not None - ] - ) - - async def test_default_engine_prefer_entity( hass: HomeAssistant, mock_tts_entity: MockTTSEntity,