Deduplicate TTS mocks (#124773)

This commit is contained in:
Erik Montnemery 2024-08-28 13:48:49 +02:00 committed by GitHub
parent 38ef216894
commit cff4e46694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 70 deletions

View File

@ -23,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
@ -37,6 +37,7 @@ from tests.common import (
mock_platform,
)
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
from tests.components.tts.common import MockTTSProvider
_TRANSCRIPT = "test transcript"
@ -48,46 +49,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir."""
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
name = "Test"
_supported_languages = ["en-US"]
_supported_voices = {
"en-US": [
tts.Voice("james_earl_jones", "James Earl Jones"),
tts.Voice("fran_drescher", "Fran Drescher"),
]
}
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
@property
def default_language(self) -> str:
"""Return the default language."""
return "en"
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return self._supported_languages
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
"""Return a list of supported voices for a language."""
return self._supported_voices.get(language)
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return self._supported_options
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS data."""
return ("mp3", b"")
class MockTTSPlatform(MockPlatform):
"""A mock TTS platform."""
@ -102,7 +63,9 @@ class MockTTSPlatform(MockPlatform):
@pytest.fixture
async def mock_tts_provider() -> MockTTSProvider:
"""Mock TTS provider."""
return MockTTSProvider()
provider = MockTTSProvider("en")
provider._supported_languages = ["en-US"]
return provider
@pytest.fixture

View File

@ -130,6 +130,8 @@ class BaseProvider:
def __init__(self, lang: str) -> None:
"""Initialize test provider."""
self._lang = lang
self._supported_languages = SUPPORT_LANGUAGES
self._supported_options = ["voice", "age"]
@property
def default_language(self) -> str:
@ -139,7 +141,7 @@ class BaseProvider:
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
return self._supported_languages
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
@ -154,7 +156,7 @@ class BaseProvider:
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return ["voice", "age"]
return self._supported_options
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
@ -163,7 +165,7 @@ class BaseProvider:
return ("mp3", b"")
class MockProvider(BaseProvider, Provider):
class MockTTSProvider(BaseProvider, Provider):
"""Test speech API provider."""
def __init__(self, lang: str) -> None:
@ -185,7 +187,7 @@ class MockTTS(MockPlatform):
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
)
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None:
"""Initialize."""
super().__init__(**kwargs)
self._provider = provider
@ -202,7 +204,7 @@ class MockTTS(MockPlatform):
async def mock_setup(
hass: HomeAssistant,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
) -> None:
"""Set up a test provider."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN))

View File

@ -17,9 +17,9 @@ from homeassistant.core import HomeAssistant
from .common import (
DEFAULT_LANG,
TEST_DOMAIN,
MockProvider,
MockTTS,
MockTTSEntity,
MockTTSProvider,
mock_config_entry_setup,
mock_setup,
)
@ -67,9 +67,9 @@ async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
@pytest.fixture
def mock_provider() -> MockProvider:
def mock_provider() -> MockTTSProvider:
"""Test TTS provider."""
return MockProvider(DEFAULT_LANG)
return MockTTSProvider(DEFAULT_LANG)
@pytest.fixture
@ -106,7 +106,7 @@ def config_flow_fixture(
async def setup_fixture(
hass: HomeAssistant,
request: pytest.FixtureRequest,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Set up the test environment."""

View File

@ -30,9 +30,9 @@ from .common import (
DEFAULT_LANG,
SUPPORT_LANGUAGES,
TEST_DOMAIN,
MockProvider,
MockTTS,
MockTTSEntity,
MockTTSProvider,
get_media_source_url,
mock_config_entry_setup,
mock_setup,
@ -220,7 +220,7 @@ async def test_service(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -281,7 +281,7 @@ async def test_service_default_language(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProvider("en_US"), MockTTSEntity("en_US"))],
[(MockTTSProvider("en_US"), MockTTSEntity("en_US"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -511,7 +511,7 @@ async def test_service_options(
).is_file()
class MockProviderWithDefaults(MockProvider):
class MockProviderWithDefaults(MockTTSProvider):
"""Mock provider with default options."""
@property
@ -854,7 +854,7 @@ async def test_service_receive_voice(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -1015,7 +1015,7 @@ async def test_service_without_cache(
).is_file()
class MockProviderBoom(MockProvider):
class MockProviderBoom(MockTTSProvider):
"""Mock provider that blows up."""
def get_tts_audio(
@ -1041,7 +1041,7 @@ class MockEntityBoom(MockTTSEntity):
async def test_setup_legacy_cache_dir(
hass: HomeAssistant,
mock_tts_cache_dir: Path,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
) -> None:
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -1106,7 +1106,7 @@ async def test_setup_cache_dir(
await hass.async_block_till_done()
class MockProviderEmpty(MockProvider):
class MockProviderEmpty(MockTTSProvider):
"""Mock provider with empty get_tts_audio."""
def get_tts_audio(
@ -1178,7 +1178,7 @@ async def test_service_get_tts_error(
async def test_load_cache_legacy_retrieve_without_mem_cache(
hass: HomeAssistant,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator,
) -> None:
@ -1426,7 +1426,7 @@ async def test_legacy_fetching_in_async(
"""Test async fetching of data for a legacy provider."""
tts_audio: asyncio.Future[bytes] = asyncio.Future()
class ProviderWithAsyncFetching(MockProvider):
class ProviderWithAsyncFetching(MockTTSProvider):
"""Provider that supports audio output option."""
@property
@ -1662,8 +1662,8 @@ async def test_ws_list_engines_deprecated(
also provides tts entities.
"""
mock_provider = MockProvider(DEFAULT_LANG)
mock_provider_2 = MockProvider(DEFAULT_LANG)
mock_provider = MockTTSProvider(DEFAULT_LANG)
mock_provider_2 = MockTTSProvider(DEFAULT_LANG)
mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS(mock_provider))
mock_integration(hass, MockModule(domain="test_2"))
@ -1910,7 +1910,7 @@ async def test_ttsentity_subclass_properties(
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
) -> None:
"""Test async_default_engine.
@ -1941,7 +1941,7 @@ async def test_default_engine_prefer_entity(
)
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.

View File

@ -17,7 +17,7 @@ from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS
from .common import SUPPORT_LANGUAGES, MockTTS, MockTTSProvider
from tests.common import (
MockModule,
@ -75,7 +75,9 @@ async def test_invalid_platform(
async def test_platform_setup_without_provider(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockTTSProvider,
) -> None:
"""Test platform setup without provider returned."""
@ -109,7 +111,7 @@ async def test_platform_setup_without_provider(
async def test_platform_setup_with_error(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockProvider,
mock_provider: MockTTSProvider,
) -> None:
"""Test platform setup with an error during setup."""

View File

@ -12,8 +12,8 @@ from homeassistant.setup import async_setup_component
from .common import (
DEFAULT_LANG,
MockProvider,
MockTTSEntity,
MockTTSProvider,
mock_config_entry_setup,
mock_setup,
retrieve_media,
@ -28,7 +28,7 @@ class MSEntity(MockTTSEntity):
get_tts_audio = MagicMock(return_value=("mp3", b""))
class MSProvider(MockProvider):
class MSProvider(MockTTSProvider):
"""Test speech API provider."""
get_tts_audio = MagicMock(return_value=("mp3", b""))