From cff4e4669426e077fc263d93d0f269adabe69e76 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 28 Aug 2024 13:48:49 +0200 Subject: [PATCH] Deduplicate TTS mocks (#124773) --- tests/components/assist_pipeline/conftest.py | 47 +++----------------- tests/components/tts/common.py | 12 ++--- tests/components/tts/conftest.py | 8 ++-- tests/components/tts/test_init.py | 28 ++++++------ tests/components/tts/test_legacy.py | 8 ++-- tests/components/tts/test_media_source.py | 4 +- 6 files changed, 37 insertions(+), 70 deletions(-) diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index c03874c16af..0f6872edbfe 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -23,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( ) from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import Platform -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.setup import async_setup_component @@ -37,6 +37,7 @@ from tests.common import ( mock_platform, ) from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity +from tests.components.tts.common import MockTTSProvider _TRANSCRIPT = "test transcript" @@ -48,46 +49,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None: """Mock the TTS cache dir with empty dir.""" -class MockTTSProvider(tts.Provider): - """Mock TTS provider.""" - - name = "Test" - _supported_languages = ["en-US"] - _supported_voices = { - "en-US": [ - tts.Voice("james_earl_jones", "James Earl Jones"), - tts.Voice("fran_drescher", "Fran Drescher"), - ] - } - _supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT] - - @property - def default_language(self) -> str: - """Return the default language.""" - return "en" - - @property - def supported_languages(self) -> list[str]: - """Return list of supported languages.""" - return self._supported_languages - - @callback - def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: - """Return a list of supported voices for a language.""" - return self._supported_voices.get(language) - - @property - def supported_options(self) -> list[str]: - """Return list of supported options like voice, emotions.""" - return self._supported_options - - def get_tts_audio( - self, message: str, language: str, options: dict[str, Any] - ) -> tts.TtsAudioType: - """Load TTS data.""" - return ("mp3", b"") - - class MockTTSPlatform(MockPlatform): """A mock TTS platform.""" @@ -102,7 +63,9 @@ class MockTTSPlatform(MockPlatform): @pytest.fixture async def mock_tts_provider() -> MockTTSProvider: """Mock TTS provider.""" - return MockTTSProvider() + provider = MockTTSProvider("en") + provider._supported_languages = ["en-US"] + return provider @pytest.fixture diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 4acba401fad..b1eae12d694 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -130,6 +130,8 @@ class BaseProvider: def __init__(self, lang: str) -> None: """Initialize test provider.""" self._lang = lang + self._supported_languages = SUPPORT_LANGUAGES + self._supported_options = ["voice", "age"] @property def default_language(self) -> str: @@ -139,7 +141,7 @@ class BaseProvider: @property def supported_languages(self) -> list[str]: """Return list of supported languages.""" - return SUPPORT_LANGUAGES + return self._supported_languages @callback def async_get_supported_voices(self, language: str) -> list[Voice] | None: @@ -154,7 +156,7 @@ class BaseProvider: @property def supported_options(self) -> list[str]: """Return list of supported options like voice, emotions.""" - return ["voice", "age"] + return self._supported_options def get_tts_audio( self, message: str, language: str, options: dict[str, Any] @@ -163,7 +165,7 @@ class BaseProvider: return ("mp3", b"") -class MockProvider(BaseProvider, Provider): +class MockTTSProvider(BaseProvider, Provider): """Test speech API provider.""" def __init__(self, lang: str) -> None: @@ -185,7 +187,7 @@ class MockTTS(MockPlatform): {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)} ) - def __init__(self, provider: MockProvider, **kwargs: Any) -> None: + def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None: """Initialize.""" super().__init__(**kwargs) self._provider = provider @@ -202,7 +204,7 @@ class MockTTS(MockPlatform): async def mock_setup( hass: HomeAssistant, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, ) -> None: """Set up a test provider.""" mock_integration(hass, MockModule(domain=TEST_DOMAIN)) diff --git a/tests/components/tts/conftest.py b/tests/components/tts/conftest.py index 91ddd7742af..16c24f006d7 100644 --- a/tests/components/tts/conftest.py +++ b/tests/components/tts/conftest.py @@ -17,9 +17,9 @@ from homeassistant.core import HomeAssistant from .common import ( DEFAULT_LANG, TEST_DOMAIN, - MockProvider, MockTTS, MockTTSEntity, + MockTTSProvider, mock_config_entry_setup, mock_setup, ) @@ -67,9 +67,9 @@ async def mock_tts(hass: HomeAssistant, mock_provider) -> None: @pytest.fixture -def mock_provider() -> MockProvider: +def mock_provider() -> MockTTSProvider: """Test TTS provider.""" - return MockProvider(DEFAULT_LANG) + return MockTTSProvider(DEFAULT_LANG) @pytest.fixture @@ -106,7 +106,7 @@ def config_flow_fixture( async def setup_fixture( hass: HomeAssistant, request: pytest.FixtureRequest, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, mock_tts_entity: MockTTSEntity, ) -> None: """Set up the test environment.""" diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 1417fcda2a7..cf04fbb175b 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -30,9 +30,9 @@ from .common import ( DEFAULT_LANG, SUPPORT_LANGUAGES, TEST_DOMAIN, - MockProvider, MockTTS, MockTTSEntity, + MockTTSProvider, get_media_source_url, mock_config_entry_setup, mock_setup, @@ -220,7 +220,7 @@ async def test_service( @pytest.mark.parametrize( ("mock_provider", "mock_tts_entity"), - [(MockProvider("de_DE"), MockTTSEntity("de_DE"))], + [(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))], ) @pytest.mark.parametrize( ("setup", "tts_service", "service_data", "expected_url_suffix"), @@ -281,7 +281,7 @@ async def test_service_default_language( @pytest.mark.parametrize( ("mock_provider", "mock_tts_entity"), - [(MockProvider("en_US"), MockTTSEntity("en_US"))], + [(MockTTSProvider("en_US"), MockTTSEntity("en_US"))], ) @pytest.mark.parametrize( ("setup", "tts_service", "service_data", "expected_url_suffix"), @@ -511,7 +511,7 @@ async def test_service_options( ).is_file() -class MockProviderWithDefaults(MockProvider): +class MockProviderWithDefaults(MockTTSProvider): """Mock provider with default options.""" @property @@ -854,7 +854,7 @@ async def test_service_receive_voice( @pytest.mark.parametrize( ("mock_provider", "mock_tts_entity"), - [(MockProvider("de_DE"), MockTTSEntity("de_DE"))], + [(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))], ) @pytest.mark.parametrize( ("setup", "tts_service", "service_data", "expected_url_suffix"), @@ -1015,7 +1015,7 @@ async def test_service_without_cache( ).is_file() -class MockProviderBoom(MockProvider): +class MockProviderBoom(MockTTSProvider): """Mock provider that blows up.""" def get_tts_audio( @@ -1041,7 +1041,7 @@ class MockEntityBoom(MockTTSEntity): async def test_setup_legacy_cache_dir( hass: HomeAssistant, mock_tts_cache_dir: Path, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, ) -> None: """Set up a TTS platform with cache and call service without cache.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -1106,7 +1106,7 @@ async def test_setup_cache_dir( await hass.async_block_till_done() -class MockProviderEmpty(MockProvider): +class MockProviderEmpty(MockTTSProvider): """Mock provider with empty get_tts_audio.""" def get_tts_audio( @@ -1178,7 +1178,7 @@ async def test_service_get_tts_error( async def test_load_cache_legacy_retrieve_without_mem_cache( hass: HomeAssistant, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, mock_tts_cache_dir: Path, hass_client: ClientSessionGenerator, ) -> None: @@ -1426,7 +1426,7 @@ async def test_legacy_fetching_in_async( """Test async fetching of data for a legacy provider.""" tts_audio: asyncio.Future[bytes] = asyncio.Future() - class ProviderWithAsyncFetching(MockProvider): + class ProviderWithAsyncFetching(MockTTSProvider): """Provider that supports audio output option.""" @property @@ -1662,8 +1662,8 @@ async def test_ws_list_engines_deprecated( also provides tts entities. """ - mock_provider = MockProvider(DEFAULT_LANG) - mock_provider_2 = MockProvider(DEFAULT_LANG) + mock_provider = MockTTSProvider(DEFAULT_LANG) + mock_provider_2 = MockTTSProvider(DEFAULT_LANG) mock_integration(hass, MockModule(domain="test")) mock_platform(hass, "test.tts", MockTTS(mock_provider)) mock_integration(hass, MockModule(domain="test_2")) @@ -1910,7 +1910,7 @@ async def test_ttsentity_subclass_properties( async def test_default_engine_prefer_entity( hass: HomeAssistant, mock_tts_entity: MockTTSEntity, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, ) -> None: """Test async_default_engine. @@ -1941,7 +1941,7 @@ async def test_default_engine_prefer_entity( ) async def test_default_engine_prefer_cloud_entity( hass: HomeAssistant, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, config_flow_test_domains: str, ) -> None: """Test async_default_engine. diff --git a/tests/components/tts/test_legacy.py b/tests/components/tts/test_legacy.py index 0d7f99e8cd1..22e8ac35f16 100644 --- a/tests/components/tts/test_legacy.py +++ b/tests/components/tts/test_legacy.py @@ -17,7 +17,7 @@ from homeassistant.helpers.discovery import async_load_platform from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_setup_component -from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS +from .common import SUPPORT_LANGUAGES, MockTTS, MockTTSProvider from tests.common import ( MockModule, @@ -75,7 +75,9 @@ async def test_invalid_platform( async def test_platform_setup_without_provider( - hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + mock_provider: MockTTSProvider, ) -> None: """Test platform setup without provider returned.""" @@ -109,7 +111,7 @@ async def test_platform_setup_without_provider( async def test_platform_setup_with_error( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, - mock_provider: MockProvider, + mock_provider: MockTTSProvider, ) -> None: """Test platform setup with an error during setup.""" diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index 4c10d8f0b08..ba856fd9622 100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -12,8 +12,8 @@ from homeassistant.setup import async_setup_component from .common import ( DEFAULT_LANG, - MockProvider, MockTTSEntity, + MockTTSProvider, mock_config_entry_setup, mock_setup, retrieve_media, @@ -28,7 +28,7 @@ class MSEntity(MockTTSEntity): get_tts_audio = MagicMock(return_value=("mp3", b"")) -class MSProvider(MockProvider): +class MSProvider(MockTTSProvider): """Test speech API provider.""" get_tts_audio = MagicMock(return_value=("mp3", b""))