Deduplicate TTS mocks (#124773)

This commit is contained in:
Erik Montnemery 2024-08-28 13:48:49 +02:00 committed by GitHub
parent 38ef216894
commit cff4e46694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 70 deletions

View File

@ -23,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
) )
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -37,6 +37,7 @@ from tests.common import (
mock_platform, mock_platform,
) )
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
from tests.components.tts.common import MockTTSProvider
_TRANSCRIPT = "test transcript" _TRANSCRIPT = "test transcript"
@ -48,46 +49,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir.""" """Mock the TTS cache dir with empty dir."""
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
name = "Test"
_supported_languages = ["en-US"]
_supported_voices = {
"en-US": [
tts.Voice("james_earl_jones", "James Earl Jones"),
tts.Voice("fran_drescher", "Fran Drescher"),
]
}
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
@property
def default_language(self) -> str:
"""Return the default language."""
return "en"
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return self._supported_languages
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
"""Return a list of supported voices for a language."""
return self._supported_voices.get(language)
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions."""
return self._supported_options
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> tts.TtsAudioType:
"""Load TTS data."""
return ("mp3", b"")
class MockTTSPlatform(MockPlatform): class MockTTSPlatform(MockPlatform):
"""A mock TTS platform.""" """A mock TTS platform."""
@ -102,7 +63,9 @@ class MockTTSPlatform(MockPlatform):
@pytest.fixture @pytest.fixture
async def mock_tts_provider() -> MockTTSProvider: async def mock_tts_provider() -> MockTTSProvider:
"""Mock TTS provider.""" """Mock TTS provider."""
return MockTTSProvider() provider = MockTTSProvider("en")
provider._supported_languages = ["en-US"]
return provider
@pytest.fixture @pytest.fixture

View File

@ -130,6 +130,8 @@ class BaseProvider:
def __init__(self, lang: str) -> None: def __init__(self, lang: str) -> None:
"""Initialize test provider.""" """Initialize test provider."""
self._lang = lang self._lang = lang
self._supported_languages = SUPPORT_LANGUAGES
self._supported_options = ["voice", "age"]
@property @property
def default_language(self) -> str: def default_language(self) -> str:
@ -139,7 +141,7 @@ class BaseProvider:
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return list of supported languages.""" """Return list of supported languages."""
return SUPPORT_LANGUAGES return self._supported_languages
@callback @callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None: def async_get_supported_voices(self, language: str) -> list[Voice] | None:
@ -154,7 +156,7 @@ class BaseProvider:
@property @property
def supported_options(self) -> list[str]: def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions.""" """Return list of supported options like voice, emotions."""
return ["voice", "age"] return self._supported_options
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, message: str, language: str, options: dict[str, Any]
@ -163,7 +165,7 @@ class BaseProvider:
return ("mp3", b"") return ("mp3", b"")
class MockProvider(BaseProvider, Provider): class MockTTSProvider(BaseProvider, Provider):
"""Test speech API provider.""" """Test speech API provider."""
def __init__(self, lang: str) -> None: def __init__(self, lang: str) -> None:
@ -185,7 +187,7 @@ class MockTTS(MockPlatform):
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)} {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
) )
def __init__(self, provider: MockProvider, **kwargs: Any) -> None: def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None:
"""Initialize.""" """Initialize."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._provider = provider self._provider = provider
@ -202,7 +204,7 @@ class MockTTS(MockPlatform):
async def mock_setup( async def mock_setup(
hass: HomeAssistant, hass: HomeAssistant,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
) -> None: ) -> None:
"""Set up a test provider.""" """Set up a test provider."""
mock_integration(hass, MockModule(domain=TEST_DOMAIN)) mock_integration(hass, MockModule(domain=TEST_DOMAIN))

View File

@ -17,9 +17,9 @@ from homeassistant.core import HomeAssistant
from .common import ( from .common import (
DEFAULT_LANG, DEFAULT_LANG,
TEST_DOMAIN, TEST_DOMAIN,
MockProvider,
MockTTS, MockTTS,
MockTTSEntity, MockTTSEntity,
MockTTSProvider,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
) )
@ -67,9 +67,9 @@ async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
@pytest.fixture @pytest.fixture
def mock_provider() -> MockProvider: def mock_provider() -> MockTTSProvider:
"""Test TTS provider.""" """Test TTS provider."""
return MockProvider(DEFAULT_LANG) return MockTTSProvider(DEFAULT_LANG)
@pytest.fixture @pytest.fixture
@ -106,7 +106,7 @@ def config_flow_fixture(
async def setup_fixture( async def setup_fixture(
hass: HomeAssistant, hass: HomeAssistant,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
) -> None: ) -> None:
"""Set up the test environment.""" """Set up the test environment."""

View File

@ -30,9 +30,9 @@ from .common import (
DEFAULT_LANG, DEFAULT_LANG,
SUPPORT_LANGUAGES, SUPPORT_LANGUAGES,
TEST_DOMAIN, TEST_DOMAIN,
MockProvider,
MockTTS, MockTTS,
MockTTSEntity, MockTTSEntity,
MockTTSProvider,
get_media_source_url, get_media_source_url,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
@ -220,7 +220,7 @@ async def test_service(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), ("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))], [(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"), ("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -281,7 +281,7 @@ async def test_service_default_language(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), ("mock_provider", "mock_tts_entity"),
[(MockProvider("en_US"), MockTTSEntity("en_US"))], [(MockTTSProvider("en_US"), MockTTSEntity("en_US"))],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"), ("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -511,7 +511,7 @@ async def test_service_options(
).is_file() ).is_file()
class MockProviderWithDefaults(MockProvider): class MockProviderWithDefaults(MockTTSProvider):
"""Mock provider with default options.""" """Mock provider with default options."""
@property @property
@ -854,7 +854,7 @@ async def test_service_receive_voice(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), ("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))], [(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"), ("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -1015,7 +1015,7 @@ async def test_service_without_cache(
).is_file() ).is_file()
class MockProviderBoom(MockProvider): class MockProviderBoom(MockTTSProvider):
"""Mock provider that blows up.""" """Mock provider that blows up."""
def get_tts_audio( def get_tts_audio(
@ -1041,7 +1041,7 @@ class MockEntityBoom(MockTTSEntity):
async def test_setup_legacy_cache_dir( async def test_setup_legacy_cache_dir(
hass: HomeAssistant, hass: HomeAssistant,
mock_tts_cache_dir: Path, mock_tts_cache_dir: Path,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
) -> None: ) -> None:
"""Set up a TTS platform with cache and call service without cache.""" """Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -1106,7 +1106,7 @@ async def test_setup_cache_dir(
await hass.async_block_till_done() await hass.async_block_till_done()
class MockProviderEmpty(MockProvider): class MockProviderEmpty(MockTTSProvider):
"""Mock provider with empty get_tts_audio.""" """Mock provider with empty get_tts_audio."""
def get_tts_audio( def get_tts_audio(
@ -1178,7 +1178,7 @@ async def test_service_get_tts_error(
async def test_load_cache_legacy_retrieve_without_mem_cache( async def test_load_cache_legacy_retrieve_without_mem_cache(
hass: HomeAssistant, hass: HomeAssistant,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
mock_tts_cache_dir: Path, mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
@ -1426,7 +1426,7 @@ async def test_legacy_fetching_in_async(
"""Test async fetching of data for a legacy provider.""" """Test async fetching of data for a legacy provider."""
tts_audio: asyncio.Future[bytes] = asyncio.Future() tts_audio: asyncio.Future[bytes] = asyncio.Future()
class ProviderWithAsyncFetching(MockProvider): class ProviderWithAsyncFetching(MockTTSProvider):
"""Provider that supports audio output option.""" """Provider that supports audio output option."""
@property @property
@ -1662,8 +1662,8 @@ async def test_ws_list_engines_deprecated(
also provides tts entities. also provides tts entities.
""" """
mock_provider = MockProvider(DEFAULT_LANG) mock_provider = MockTTSProvider(DEFAULT_LANG)
mock_provider_2 = MockProvider(DEFAULT_LANG) mock_provider_2 = MockTTSProvider(DEFAULT_LANG)
mock_integration(hass, MockModule(domain="test")) mock_integration(hass, MockModule(domain="test"))
mock_platform(hass, "test.tts", MockTTS(mock_provider)) mock_platform(hass, "test.tts", MockTTS(mock_provider))
mock_integration(hass, MockModule(domain="test_2")) mock_integration(hass, MockModule(domain="test_2"))
@ -1910,7 +1910,7 @@ async def test_ttsentity_subclass_properties(
async def test_default_engine_prefer_entity( async def test_default_engine_prefer_entity(
hass: HomeAssistant, hass: HomeAssistant,
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
) -> None: ) -> None:
"""Test async_default_engine. """Test async_default_engine.
@ -1941,7 +1941,7 @@ async def test_default_engine_prefer_entity(
) )
async def test_default_engine_prefer_cloud_entity( async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant, hass: HomeAssistant,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
config_flow_test_domains: str, config_flow_test_domains: str,
) -> None: ) -> None:
"""Test async_default_engine. """Test async_default_engine.

View File

@ -17,7 +17,7 @@ from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS from .common import SUPPORT_LANGUAGES, MockTTS, MockTTSProvider
from tests.common import ( from tests.common import (
MockModule, MockModule,
@ -75,7 +75,9 @@ async def test_invalid_platform(
async def test_platform_setup_without_provider( async def test_platform_setup_without_provider(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mock_provider: MockTTSProvider,
) -> None: ) -> None:
"""Test platform setup without provider returned.""" """Test platform setup without provider returned."""
@ -109,7 +111,7 @@ async def test_platform_setup_without_provider(
async def test_platform_setup_with_error( async def test_platform_setup_with_error(
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
mock_provider: MockProvider, mock_provider: MockTTSProvider,
) -> None: ) -> None:
"""Test platform setup with an error during setup.""" """Test platform setup with an error during setup."""

View File

@ -12,8 +12,8 @@ from homeassistant.setup import async_setup_component
from .common import ( from .common import (
DEFAULT_LANG, DEFAULT_LANG,
MockProvider,
MockTTSEntity, MockTTSEntity,
MockTTSProvider,
mock_config_entry_setup, mock_config_entry_setup,
mock_setup, mock_setup,
retrieve_media, retrieve_media,
@ -28,7 +28,7 @@ class MSEntity(MockTTSEntity):
get_tts_audio = MagicMock(return_value=("mp3", b"")) get_tts_audio = MagicMock(return_value=("mp3", b""))
class MSProvider(MockProvider): class MSProvider(MockTTSProvider):
"""Test speech API provider.""" """Test speech API provider."""
get_tts_audio = MagicMock(return_value=("mp3", b"")) get_tts_audio = MagicMock(return_value=("mp3", b""))