From 3b1b33d7f20adb83f94af0d834809bc24d2e3715 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 22 Apr 2025 10:32:44 -0400 Subject: [PATCH] Move method to be test suite only --- homeassistant/components/tts/__init__.py | 16 +------------- tests/components/tts/common.py | 14 ++++++++++++ tests/components/tts/test_init.py | 27 +++++++++--------------- tests/components/wyoming/test_tts.py | 12 ++++++----- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 8182d375f96..b000e507509 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -63,7 +63,7 @@ from .const import ( from .entity import TextToSpeechEntity, TTSAudioRequest from .helper import get_engine_instance from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy -from .media_source import generate_media_source_id, media_source_id_to_kwargs +from .media_source import generate_media_source_id from .models import Voice __all__ = [ @@ -83,7 +83,6 @@ __all__ = [ "TtsAudioType", "Voice", "async_default_engine", - "async_get_media_source_audio", "generate_media_source_id", ] @@ -267,19 +266,6 @@ def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None: return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token) -async def async_get_media_source_audio( - hass: HomeAssistant, - media_source_id: str, -) -> tuple[str, bytes]: - """Get TTS audio as extension, data.""" - manager = hass.data[DATA_TTS_MANAGER] - cache = manager.async_cache_message_in_memory( - **media_source_id_to_kwargs(media_source_id) - ) - data = b"".join([chunk async for chunk in cache.async_stream_data()]) - return cache.extension, data - - @callback def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]: """Return a set with the union of languages supported by tts engines.""" diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 99c698771f7..9821d9389d2 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -24,6 +24,7 @@ from homeassistant.components.tts import ( Voice, _get_cache_files, ) +from homeassistant.components.tts.media_source import media_source_id_to_kwargs from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback @@ -44,6 +45,19 @@ SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] TEST_DOMAIN = "test" +async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, +) -> tuple[str, bytes]: + """Get TTS audio as extension, data.""" + manager = hass.data[DATA_TTS_MANAGER] + cache = manager.async_cache_message_in_memory( + **media_source_id_to_kwargs(media_source_id) + ) + data = b"".join([chunk async for chunk in cache.async_stream_data()]) + return cache.extension, data + + def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]: """Mock the list TTS cache function.""" with patch( diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 4e17bc68a5e..5bfd6e272ea 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -32,6 +32,7 @@ from .common import ( MockTTS, MockTTSEntity, MockTTSProvider, + async_get_media_source_audio, get_media_source_url, mock_config_entry_setup, mock_setup, @@ -820,7 +821,7 @@ async def test_service_receive_voice( assert req.status == HTTPStatus.OK assert await req.read() == tts_data - extension, data = await tts.async_get_media_source_audio( + extension, data = await async_get_media_source_audio( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID] ) assert extension == "mp3" @@ -1412,12 +1413,8 @@ async def test_legacy_fetching_in_async( cache=None, ) - task = hass.async_create_task( - tts.async_get_media_source_audio(hass, media_source_id) - ) - task2 = hass.async_create_task( - tts.async_get_media_source_audio(hass, media_source_id) - ) + task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id)) + task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id)) url = await get_media_source_url(hass, media_source_id) client = await hass_client() @@ -1444,11 +1441,11 @@ async def test_legacy_fetching_in_async( tts_audio = asyncio.Future() tts_audio.set_exception(HomeAssistantError("test error")) with pytest.raises(HomeAssistantError): - assert await tts.async_get_media_source_audio(hass, media_source_id) + assert await async_get_media_source_audio(hass, media_source_id) tts_audio = asyncio.Future() tts_audio.set_result(b"test 2") - assert await tts.async_get_media_source_audio(hass, media_source_id) == ( + assert await async_get_media_source_audio(hass, media_source_id) == ( "mp3", b"test 2", ) @@ -1479,12 +1476,8 @@ async def test_fetching_in_async( cache=None, ) - task = hass.async_create_task( - tts.async_get_media_source_audio(hass, media_source_id) - ) - task2 = hass.async_create_task( - tts.async_get_media_source_audio(hass, media_source_id) - ) + task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id)) + task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id)) url = await get_media_source_url(hass, media_source_id) client = await hass_client() @@ -1511,11 +1504,11 @@ async def test_fetching_in_async( tts_audio = asyncio.Future() tts_audio.set_exception(HomeAssistantError("test error")) with pytest.raises(HomeAssistantError): - assert await tts.async_get_media_source_audio(hass, media_source_id) + assert await async_get_media_source_audio(hass, media_source_id) tts_audio = asyncio.Future() tts_audio.set_result(b"test 2") - assert await tts.async_get_media_source_audio(hass, media_source_id) == ( + assert await async_get_media_source_audio(hass, media_source_id) == ( "mp3", b"test 2", ) diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py index c52b1391038..9717c634d9b 100644 --- a/tests/components/wyoming/test_tts.py +++ b/tests/components/wyoming/test_tts.py @@ -17,6 +17,8 @@ from homeassistant.helpers.entity_component import DATA_INSTANCES from . import MockAsyncTcpClient +from tests.components.tts.common import async_get_media_source_audio + async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None: """Test supported properties.""" @@ -59,7 +61,7 @@ async def test_get_tts_audio( "homeassistant.components.wyoming.tts.AsyncTcpClient", MockAsyncTcpClient(audio_events), ) as mock_client: - extension, data = await tts.async_get_media_source_audio( + extension, data = await async_get_media_source_audio( hass, tts.generate_media_source_id( hass, @@ -96,7 +98,7 @@ async def test_get_tts_audio_different_formats( "homeassistant.components.wyoming.tts.AsyncTcpClient", MockAsyncTcpClient(audio_events), ) as mock_client: - extension, data = await tts.async_get_media_source_audio( + extension, data = await async_get_media_source_audio( hass, tts.generate_media_source_id( hass, @@ -130,7 +132,7 @@ async def test_get_tts_audio_different_formats( "homeassistant.components.wyoming.tts.AsyncTcpClient", MockAsyncTcpClient(audio_events), ) as mock_client: - extension, data = await tts.async_get_media_source_audio( + extension, data = await async_get_media_source_audio( hass, tts.generate_media_source_id( hass, @@ -182,7 +184,7 @@ async def test_get_tts_audio_audio_oserror( HomeAssistantError, ), ): - await tts.async_get_media_source_audio( + await async_get_media_source_audio( hass, tts.generate_media_source_id( hass, "Hello world", "tts.test_tts", hass.config.language @@ -204,7 +206,7 @@ async def test_voice_speaker( "homeassistant.components.wyoming.tts.AsyncTcpClient", MockAsyncTcpClient(audio_events), ) as mock_client: - await tts.async_get_media_source_audio( + await async_get_media_source_audio( hass, tts.generate_media_source_id( hass,