Move method to be test suite only

This commit is contained in:
Paulus Schoutsen 2025-04-22 10:32:44 -04:00
parent 8aa30b0ccb
commit 3b1b33d7f2
4 changed files with 32 additions and 37 deletions

View File

@ -63,7 +63,7 @@ from .const import (
from .entity import TextToSpeechEntity, TTSAudioRequest from .entity import TextToSpeechEntity, TTSAudioRequest
from .helper import get_engine_instance from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs from .media_source import generate_media_source_id
from .models import Voice from .models import Voice
__all__ = [ __all__ = [
@ -83,7 +83,6 @@ __all__ = [
"TtsAudioType", "TtsAudioType",
"Voice", "Voice",
"async_default_engine", "async_default_engine",
"async_get_media_source_audio",
"generate_media_source_id", "generate_media_source_id",
] ]
@ -267,19 +266,6 @@ def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None:
return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token) return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token)
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager = hass.data[DATA_TTS_MANAGER]
cache = manager.async_cache_message_in_memory(
**media_source_id_to_kwargs(media_source_id)
)
data = b"".join([chunk async for chunk in cache.async_stream_data()])
return cache.extension, data
@callback @callback
def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]: def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by tts engines.""" """Return a set with the union of languages supported by tts engines."""

View File

@ -24,6 +24,7 @@ from homeassistant.components.tts import (
Voice, Voice,
_get_cache_files, _get_cache_files,
) )
from homeassistant.components.tts.media_source import media_source_id_to_kwargs
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
@ -44,6 +45,19 @@ SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test" TEST_DOMAIN = "test"
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager = hass.data[DATA_TTS_MANAGER]
cache = manager.async_cache_message_in_memory(
**media_source_id_to_kwargs(media_source_id)
)
data = b"".join([chunk async for chunk in cache.async_stream_data()])
return cache.extension, data
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]: def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
"""Mock the list TTS cache function.""" """Mock the list TTS cache function."""
with patch( with patch(

View File

@ -32,6 +32,7 @@ from .common import (
MockTTS, MockTTS,
MockTTSEntity, MockTTSEntity,
MockTTSProvider, MockTTSProvider,
async_get_media_source_audio,
get_media_source_url, get_media_source_url,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
@ -820,7 +821,7 @@ async def test_service_receive_voice(
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
assert await req.read() == tts_data assert await req.read() == tts_data
extension, data = await tts.async_get_media_source_audio( extension, data = await async_get_media_source_audio(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID] hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) )
assert extension == "mp3" assert extension == "mp3"
@ -1412,12 +1413,8 @@ async def test_legacy_fetching_in_async(
cache=None, cache=None,
) )
task = hass.async_create_task( task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
tts.async_get_media_source_audio(hass, media_source_id) task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
)
task2 = hass.async_create_task(
tts.async_get_media_source_audio(hass, media_source_id)
)
url = await get_media_source_url(hass, media_source_id) url = await get_media_source_url(hass, media_source_id)
client = await hass_client() client = await hass_client()
@ -1444,11 +1441,11 @@ async def test_legacy_fetching_in_async(
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error")) tts_audio.set_exception(HomeAssistantError("test error"))
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
assert await tts.async_get_media_source_audio(hass, media_source_id) assert await async_get_media_source_audio(hass, media_source_id)
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_result(b"test 2") tts_audio.set_result(b"test 2")
assert await tts.async_get_media_source_audio(hass, media_source_id) == ( assert await async_get_media_source_audio(hass, media_source_id) == (
"mp3", "mp3",
b"test 2", b"test 2",
) )
@ -1479,12 +1476,8 @@ async def test_fetching_in_async(
cache=None, cache=None,
) )
task = hass.async_create_task( task = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
tts.async_get_media_source_audio(hass, media_source_id) task2 = hass.async_create_task(async_get_media_source_audio(hass, media_source_id))
)
task2 = hass.async_create_task(
tts.async_get_media_source_audio(hass, media_source_id)
)
url = await get_media_source_url(hass, media_source_id) url = await get_media_source_url(hass, media_source_id)
client = await hass_client() client = await hass_client()
@ -1511,11 +1504,11 @@ async def test_fetching_in_async(
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error")) tts_audio.set_exception(HomeAssistantError("test error"))
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
assert await tts.async_get_media_source_audio(hass, media_source_id) assert await async_get_media_source_audio(hass, media_source_id)
tts_audio = asyncio.Future() tts_audio = asyncio.Future()
tts_audio.set_result(b"test 2") tts_audio.set_result(b"test 2")
assert await tts.async_get_media_source_audio(hass, media_source_id) == ( assert await async_get_media_source_audio(hass, media_source_id) == (
"mp3", "mp3",
b"test 2", b"test 2",
) )

View File

@ -17,6 +17,8 @@ from homeassistant.helpers.entity_component import DATA_INSTANCES
from . import MockAsyncTcpClient from . import MockAsyncTcpClient
from tests.components.tts.common import async_get_media_source_audio
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None: async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
"""Test supported properties.""" """Test supported properties."""
@ -59,7 +61,7 @@ async def test_get_tts_audio(
"homeassistant.components.wyoming.tts.AsyncTcpClient", "homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events), MockAsyncTcpClient(audio_events),
) as mock_client: ) as mock_client:
extension, data = await tts.async_get_media_source_audio( extension, data = await async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(
hass, hass,
@ -96,7 +98,7 @@ async def test_get_tts_audio_different_formats(
"homeassistant.components.wyoming.tts.AsyncTcpClient", "homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events), MockAsyncTcpClient(audio_events),
) as mock_client: ) as mock_client:
extension, data = await tts.async_get_media_source_audio( extension, data = await async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(
hass, hass,
@ -130,7 +132,7 @@ async def test_get_tts_audio_different_formats(
"homeassistant.components.wyoming.tts.AsyncTcpClient", "homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events), MockAsyncTcpClient(audio_events),
) as mock_client: ) as mock_client:
extension, data = await tts.async_get_media_source_audio( extension, data = await async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(
hass, hass,
@ -182,7 +184,7 @@ async def test_get_tts_audio_audio_oserror(
HomeAssistantError, HomeAssistantError,
), ),
): ):
await tts.async_get_media_source_audio( await async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(
hass, "Hello world", "tts.test_tts", hass.config.language hass, "Hello world", "tts.test_tts", hass.config.language
@ -204,7 +206,7 @@ async def test_voice_speaker(
"homeassistant.components.wyoming.tts.AsyncTcpClient", "homeassistant.components.wyoming.tts.AsyncTcpClient",
MockAsyncTcpClient(audio_events), MockAsyncTcpClient(audio_events),
) as mock_client: ) as mock_client:
await tts.async_get_media_source_audio( await async_get_media_source_audio(
hass, hass,
tts.generate_media_source_id( tts.generate_media_source_id(
hass, hass,