mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Deduplicate TTS mocks (#124773)
This commit is contained in:
parent
38ef216894
commit
cff4e46694
@ -23,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
@ -37,6 +37,7 @@ from tests.common import (
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
||||
from tests.components.tts.common import MockTTSProvider
|
||||
|
||||
_TRANSCRIPT = "test transcript"
|
||||
|
||||
@ -48,46 +49,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
|
||||
|
||||
class MockTTSProvider(tts.Provider):
|
||||
"""Mock TTS provider."""
|
||||
|
||||
name = "Test"
|
||||
_supported_languages = ["en-US"]
|
||||
_supported_voices = {
|
||||
"en-US": [
|
||||
tts.Voice("james_earl_jones", "James Earl Jones"),
|
||||
tts.Voice("fran_drescher", "Fran Drescher"),
|
||||
]
|
||||
}
|
||||
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
"""Return the default language."""
|
||||
return "en"
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return self._supported_languages
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._supported_voices.get(language)
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return self._supported_options
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> tts.TtsAudioType:
|
||||
"""Load TTS data."""
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockTTSPlatform(MockPlatform):
|
||||
"""A mock TTS platform."""
|
||||
|
||||
@ -102,7 +63,9 @@ class MockTTSPlatform(MockPlatform):
|
||||
@pytest.fixture
|
||||
async def mock_tts_provider() -> MockTTSProvider:
|
||||
"""Mock TTS provider."""
|
||||
return MockTTSProvider()
|
||||
provider = MockTTSProvider("en")
|
||||
provider._supported_languages = ["en-US"]
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -130,6 +130,8 @@ class BaseProvider:
|
||||
def __init__(self, lang: str) -> None:
|
||||
"""Initialize test provider."""
|
||||
self._lang = lang
|
||||
self._supported_languages = SUPPORT_LANGUAGES
|
||||
self._supported_options = ["voice", "age"]
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
@ -139,7 +141,7 @@ class BaseProvider:
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
return self._supported_languages
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||
@ -154,7 +156,7 @@ class BaseProvider:
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return ["voice", "age"]
|
||||
return self._supported_options
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
@ -163,7 +165,7 @@ class BaseProvider:
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockProvider(BaseProvider, Provider):
|
||||
class MockTTSProvider(BaseProvider, Provider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
def __init__(self, lang: str) -> None:
|
||||
@ -185,7 +187,7 @@ class MockTTS(MockPlatform):
|
||||
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
|
||||
)
|
||||
|
||||
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:
|
||||
def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__(**kwargs)
|
||||
self._provider = provider
|
||||
@ -202,7 +204,7 @@ class MockTTS(MockPlatform):
|
||||
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Set up a test provider."""
|
||||
mock_integration(hass, MockModule(domain=TEST_DOMAIN))
|
||||
|
@ -17,9 +17,9 @@ from homeassistant.core import HomeAssistant
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
TEST_DOMAIN,
|
||||
MockProvider,
|
||||
MockTTS,
|
||||
MockTTSEntity,
|
||||
MockTTSProvider,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
)
|
||||
@ -67,9 +67,9 @@ async def mock_tts(hass: HomeAssistant, mock_provider) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider() -> MockProvider:
|
||||
def mock_provider() -> MockTTSProvider:
|
||||
"""Test TTS provider."""
|
||||
return MockProvider(DEFAULT_LANG)
|
||||
return MockTTSProvider(DEFAULT_LANG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -106,7 +106,7 @@ def config_flow_fixture(
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
request: pytest.FixtureRequest,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
) -> None:
|
||||
"""Set up the test environment."""
|
||||
|
@ -30,9 +30,9 @@ from .common import (
|
||||
DEFAULT_LANG,
|
||||
SUPPORT_LANGUAGES,
|
||||
TEST_DOMAIN,
|
||||
MockProvider,
|
||||
MockTTS,
|
||||
MockTTSEntity,
|
||||
MockTTSProvider,
|
||||
get_media_source_url,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
@ -220,7 +220,7 @@ async def test_service(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
|
||||
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data", "expected_url_suffix"),
|
||||
@ -281,7 +281,7 @@ async def test_service_default_language(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
[(MockProvider("en_US"), MockTTSEntity("en_US"))],
|
||||
[(MockTTSProvider("en_US"), MockTTSEntity("en_US"))],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data", "expected_url_suffix"),
|
||||
@ -511,7 +511,7 @@ async def test_service_options(
|
||||
).is_file()
|
||||
|
||||
|
||||
class MockProviderWithDefaults(MockProvider):
|
||||
class MockProviderWithDefaults(MockTTSProvider):
|
||||
"""Mock provider with default options."""
|
||||
|
||||
@property
|
||||
@ -854,7 +854,7 @@ async def test_service_receive_voice(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
|
||||
[(MockTTSProvider("de_DE"), MockTTSEntity("de_DE"))],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "tts_service", "service_data", "expected_url_suffix"),
|
||||
@ -1015,7 +1015,7 @@ async def test_service_without_cache(
|
||||
).is_file()
|
||||
|
||||
|
||||
class MockProviderBoom(MockProvider):
|
||||
class MockProviderBoom(MockTTSProvider):
|
||||
"""Mock provider that blows up."""
|
||||
|
||||
def get_tts_audio(
|
||||
@ -1041,7 +1041,7 @@ class MockEntityBoom(MockTTSEntity):
|
||||
async def test_setup_legacy_cache_dir(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_cache_dir: Path,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Set up a TTS platform with cache and call service without cache."""
|
||||
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
@ -1106,7 +1106,7 @@ async def test_setup_cache_dir(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
class MockProviderEmpty(MockProvider):
|
||||
class MockProviderEmpty(MockTTSProvider):
|
||||
"""Mock provider with empty get_tts_audio."""
|
||||
|
||||
def get_tts_audio(
|
||||
@ -1178,7 +1178,7 @@ async def test_service_get_tts_error(
|
||||
|
||||
async def test_load_cache_legacy_retrieve_without_mem_cache(
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
mock_tts_cache_dir: Path,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
@ -1426,7 +1426,7 @@ async def test_legacy_fetching_in_async(
|
||||
"""Test async fetching of data for a legacy provider."""
|
||||
tts_audio: asyncio.Future[bytes] = asyncio.Future()
|
||||
|
||||
class ProviderWithAsyncFetching(MockProvider):
|
||||
class ProviderWithAsyncFetching(MockTTSProvider):
|
||||
"""Provider that supports audio output option."""
|
||||
|
||||
@property
|
||||
@ -1662,8 +1662,8 @@ async def test_ws_list_engines_deprecated(
|
||||
also provides tts entities.
|
||||
"""
|
||||
|
||||
mock_provider = MockProvider(DEFAULT_LANG)
|
||||
mock_provider_2 = MockProvider(DEFAULT_LANG)
|
||||
mock_provider = MockTTSProvider(DEFAULT_LANG)
|
||||
mock_provider_2 = MockTTSProvider(DEFAULT_LANG)
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
mock_platform(hass, "test.tts", MockTTS(mock_provider))
|
||||
mock_integration(hass, MockModule(domain="test_2"))
|
||||
@ -1910,7 +1910,7 @@ async def test_ttsentity_subclass_properties(
|
||||
async def test_default_engine_prefer_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test async_default_engine.
|
||||
|
||||
@ -1941,7 +1941,7 @@ async def test_default_engine_prefer_entity(
|
||||
)
|
||||
async def test_default_engine_prefer_cloud_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
config_flow_test_domains: str,
|
||||
) -> None:
|
||||
"""Test async_default_engine.
|
||||
|
@ -17,7 +17,7 @@ from homeassistant.helpers.discovery import async_load_platform
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import SUPPORT_LANGUAGES, MockProvider, MockTTS
|
||||
from .common import SUPPORT_LANGUAGES, MockTTS, MockTTSProvider
|
||||
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
@ -75,7 +75,9 @@ async def test_invalid_platform(
|
||||
|
||||
|
||||
async def test_platform_setup_without_provider(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, mock_provider: MockProvider
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
mock_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test platform setup without provider returned."""
|
||||
|
||||
@ -109,7 +111,7 @@ async def test_platform_setup_without_provider(
|
||||
async def test_platform_setup_with_error(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
mock_provider: MockProvider,
|
||||
mock_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test platform setup with an error during setup."""
|
||||
|
||||
|
@ -12,8 +12,8 @@ from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
MockProvider,
|
||||
MockTTSEntity,
|
||||
MockTTSProvider,
|
||||
mock_config_entry_setup,
|
||||
mock_setup,
|
||||
retrieve_media,
|
||||
@ -28,7 +28,7 @@ class MSEntity(MockTTSEntity):
|
||||
get_tts_audio = MagicMock(return_value=("mp3", b""))
|
||||
|
||||
|
||||
class MSProvider(MockProvider):
|
||||
class MSProvider(MockTTSProvider):
|
||||
"""Test speech API provider."""
|
||||
|
||||
get_tts_audio = MagicMock(return_value=("mp3", b""))
|
||||
|
Loading…
x
Reference in New Issue
Block a user