Deduplicate STT mocks (#124754)

This commit is contained in:
Erik Montnemery 2024-08-28 09:25:56 +02:00 committed by GitHub
parent f9bf7f7e05
commit 1add00a68d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 147 additions and 189 deletions

View File

@ -36,6 +36,7 @@ from tests.common import (
mock_integration,
mock_platform,
)
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
_TRANSCRIPT = "test transcript"
@ -47,67 +48,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir."""
class BaseProvider:
"""Mock STT provider."""
_supported_languages = ["en-US"]
def __init__(self, text: str) -> None:
"""Init test provider."""
self.text = text
self.received: list[bytes] = []
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages
@property
def supported_formats(self) -> list[stt.AudioFormats]:
"""Return a list of supported formats."""
return [stt.AudioFormats.WAV]
@property
def supported_codecs(self) -> list[stt.AudioCodecs]:
"""Return a list of supported codecs."""
return [stt.AudioCodecs.PCM]
@property
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
"""Return a list of supported bitrates."""
return [stt.AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
"""Return a list of supported samplerates."""
return [stt.AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[stt.AudioChannels]:
"""Return a list of supported channels."""
return [stt.AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
) -> stt.SpeechResult:
"""Process an audio stream."""
async for data in stream:
if not data:
break
self.received.append(data)
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSttProvider(BaseProvider, stt.Provider):
"""Mock provider."""
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
"""Mock provider entity."""
_attr_name = "Mock STT"
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
@ -166,15 +106,17 @@ async def mock_tts_provider() -> MockTTSProvider:
@pytest.fixture
async def mock_stt_provider() -> MockSttProvider:
async def mock_stt_provider() -> MockSTTProvider:
"""Mock STT provider."""
return MockSttProvider(_TRANSCRIPT)
return MockSTTProvider(supported_languages=["en-US"], text=_TRANSCRIPT)
@pytest.fixture
def mock_stt_provider_entity() -> MockSttProviderEntity:
def mock_stt_provider_entity() -> MockSTTProviderEntity:
"""Test provider entity fixture."""
return MockSttProviderEntity(_TRANSCRIPT)
entity = MockSTTProviderEntity(supported_languages=["en-US"], text=_TRANSCRIPT)
entity._attr_name = "Mock STT"
return entity
class MockSttPlatform(MockPlatform):
@ -290,8 +232,8 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
@pytest.fixture
async def init_supporting_components(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
mock_stt_provider: MockSTTProvider,
mock_stt_provider_entity: MockSTTProviderEntity,
mock_tts_provider: MockTTSProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
mock_wake_word_provider_entity2: MockWakeWordEntity2,

View File

@ -22,8 +22,8 @@ from homeassistant.setup import async_setup_component
from .conftest import (
BYTES_ONE_SECOND,
MockSttProvider,
MockSttProviderEntity,
MockSTTProvider,
MockSTTProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
make_10ms_chunk,
@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity,
mock_stt_provider_entity: MockSTTProviderEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -88,7 +88,7 @@ async def test_pipeline_from_audio_stream_auto(
async def test_pipeline_from_audio_stream_legacy(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -153,7 +153,7 @@ async def test_pipeline_from_audio_stream_legacy(
async def test_pipeline_from_audio_stream_entity(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider_entity: MockSttProviderEntity,
mock_stt_provider_entity: MockSTTProviderEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -218,7 +218,7 @@ async def test_pipeline_from_audio_stream_entity(
async def test_pipeline_from_audio_stream_no_stt(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -281,7 +281,7 @@ async def test_pipeline_from_audio_stream_no_stt(
async def test_pipeline_from_audio_stream_unknown_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity,
mock_stt_provider_entity: MockSTTProviderEntity,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
snapshot: SnapshotAssertion,
@ -395,7 +395,7 @@ async def test_pipeline_from_audio_stream_wake_word(
async def test_pipeline_save_audio(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
@ -474,7 +474,7 @@ async def test_pipeline_save_audio(
async def test_pipeline_saved_audio_with_device_id(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
@ -529,7 +529,7 @@ async def test_pipeline_saved_audio_with_device_id(
async def test_pipeline_saved_audio_write_error(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
@ -578,7 +578,7 @@ async def test_pipeline_saved_audio_write_error(
async def test_pipeline_saved_audio_empty_queue(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components,
snapshot: SnapshotAssertion,
@ -641,7 +641,7 @@ async def test_pipeline_saved_audio_empty_queue(
async def test_wake_word_detection_aborted(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,

View File

@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
from .conftest import MockSttProviderEntity, MockTTSProvider
from .conftest import MockSTTProviderEntity, MockTTSProvider
from tests.common import flush_store
@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts(
@pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline(
hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity,
mock_stt_provider_entity: MockSTTProviderEntity,
mock_tts_provider: MockTTSProvider,
ha_language: str,
ha_country: str | None,
@ -441,7 +441,7 @@ async def test_default_pipeline(
@pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):

View File

@ -743,7 +743,7 @@ async def test_stt_stream_failed(
client = await hass_ws_client(hass)
with patch(
"tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream",
"tests.components.assist_pipeline.conftest.MockSTTProviderEntity.async_process_audio_stream",
side_effect=RuntimeError,
):
await client.send_json_auto_id(

View File

@ -2,11 +2,22 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from collections.abc import AsyncIterable, Callable, Coroutine
from pathlib import Path
from typing import Any
from homeassistant.components.stt import Provider
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -14,6 +25,80 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform, mock_platform
TEST_DOMAIN = "test"
class BaseProvider:
"""Mock STT provider."""
fail_process_audio = False
def __init__(
self, *, supported_languages: list[str] | None = None, text: str = "test_result"
) -> None:
"""Init test provider."""
self._supported_languages = supported_languages or ["de", "de-CH", "en"]
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
self.received: list[bytes] = []
self.text = text
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV, AudioFormats.OGG]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM, AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bitrates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported samplerates."""
return [AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))
async for data in stream:
if not data:
break
self.received.append(data)
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult(self.text, SpeechResultState.SUCCESS)
class MockSTTProvider(BaseProvider, Provider):
"""Mock provider."""
url_path = TEST_DOMAIN
class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
class MockSTTPlatform(MockPlatform):
"""Help to set up test stt service."""

View File

@ -1,6 +1,6 @@
"""Test STT component setup."""
from collections.abc import AsyncIterable, Generator, Iterable
from collections.abc import Generator, Iterable
from contextlib import ExitStack
from http import HTTPStatus
from pathlib import Path
@ -10,16 +10,6 @@ import pytest
from homeassistant.components.stt import (
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
async_default_engine,
async_get_provider,
async_get_speech_to_text_engine,
@ -29,7 +19,13 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from .common import mock_stt_entity_platform, mock_stt_platform
from .common import (
TEST_DOMAIN,
MockSTTProvider,
MockSTTProviderEntity,
mock_stt_entity_platform,
mock_stt_platform,
)
from tests.common import (
MockConfigEntry,
@ -41,82 +37,17 @@ from tests.common import (
)
from tests.typing import ClientSessionGenerator, WebSocketGenerator
TEST_DOMAIN = "test"
class BaseProvider:
"""Mock provider."""
fail_process_audio = False
def __init__(self) -> None:
"""Init test provider."""
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["de", "de-CH", "en"]
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV, AudioFormats.OGG]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM, AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bitrates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported samplerates."""
return [AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult("test_result", SpeechResultState.SUCCESS)
class MockProvider(BaseProvider, Provider):
"""Mock provider."""
url_path = TEST_DOMAIN
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
@pytest.fixture
def mock_provider() -> MockProvider:
def mock_provider() -> MockSTTProvider:
"""Test provider fixture."""
return MockProvider()
return MockSTTProvider()
@pytest.fixture
def mock_provider_entity() -> MockProviderEntity:
def mock_provider_entity() -> MockSTTProviderEntity:
"""Test provider entity fixture."""
return MockProviderEntity()
return MockSTTProviderEntity()
class STTFlow(ConfigFlow):
@ -148,14 +79,14 @@ async def setup_fixture(
hass: HomeAssistant,
tmp_path: Path,
request: pytest.FixtureRequest,
) -> MockProvider | MockProviderEntity:
) -> MockSTTProvider | MockSTTProviderEntity:
"""Set up the test environment."""
provider: MockProvider | MockProviderEntity
provider: MockSTTProvider | MockSTTProviderEntity
if request.param == "mock_setup":
provider = MockProvider()
provider = MockSTTProvider()
await mock_setup(hass, tmp_path, provider)
elif request.param == "mock_config_entry_setup":
provider = MockProviderEntity()
provider = MockSTTProviderEntity()
await mock_config_entry_setup(hass, tmp_path, provider)
else:
raise RuntimeError("Invalid setup fixture")
@ -166,7 +97,7 @@ async def setup_fixture(
async def mock_setup(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
mock_provider: MockSTTProvider,
) -> None:
"""Set up a test provider."""
mock_stt_platform(
@ -182,7 +113,7 @@ async def mock_setup(
async def mock_config_entry_setup(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider_entity: MockSTTProviderEntity,
test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry:
"""Set up a test provider via config entry."""
@ -234,7 +165,7 @@ async def mock_config_entry_setup(
async def test_get_provider_info(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity,
setup: MockSTTProvider | MockSTTProviderEntity,
) -> None:
"""Test engine that doesn't exist."""
client = await hass_client()
@ -256,7 +187,7 @@ async def test_get_provider_info(
async def test_non_existing_provider(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity,
setup: MockSTTProvider | MockSTTProviderEntity,
) -> None:
"""Test streaming to engine that doesn't exist."""
client = await hass_client()
@ -282,7 +213,7 @@ async def test_non_existing_provider(
async def test_stream_audio(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity,
setup: MockSTTProvider | MockSTTProviderEntity,
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
@ -343,7 +274,7 @@ async def test_metadata_errors(
header: str | None,
status: int,
error: str,
setup: MockProvider | MockProviderEntity,
setup: MockSTTProvider | MockSTTProviderEntity,
) -> None:
"""Test metadata errors."""
client = await hass_client()
@ -359,7 +290,7 @@ async def test_metadata_errors(
async def test_get_provider(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
mock_provider: MockSTTProvider,
) -> None:
"""Test we can get STT providers."""
await mock_setup(hass, tmp_path, mock_provider)
@ -370,7 +301,7 @@ async def test_get_provider(
async def test_config_entry_unload(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None:
"""Test we can unload config entry."""
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
@ -382,7 +313,7 @@ async def test_config_entry_unload(
async def test_restore_state(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider_entity: MockSTTProviderEntity,
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
@ -409,7 +340,7 @@ async def test_restore_state(
async def test_ws_list_engines(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
setup: MockProvider | MockProviderEntity,
setup: MockSTTProvider | MockSTTProviderEntity,
engine_id: str,
extra_data: dict[str, str],
) -> None:
@ -491,7 +422,7 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
async def test_default_engine(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
mock_provider: MockSTTProvider,
) -> None:
"""Test async_default_engine."""
mock_stt_platform(
@ -507,7 +438,7 @@ async def test_default_engine(
async def test_default_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None:
"""Test async_default_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
@ -519,8 +450,8 @@ async def test_default_engine_entity(
async def test_default_engine_prefer_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
mock_provider: MockProvider,
mock_provider_entity: MockSTTProviderEntity,
mock_provider: MockSTTProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
@ -558,7 +489,7 @@ async def test_default_engine_prefer_entity(
async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
mock_provider: MockSTTProvider,
config_flow_test_domains: str,
) -> None:
"""Test async_default_engine.
@ -569,7 +500,7 @@ async def test_default_engine_prefer_cloud_entity(
"""
await mock_setup(hass, tmp_path, mock_provider)
for domain in config_flow_test_domains:
entity = MockProviderEntity()
entity = MockSTTProviderEntity()
entity.url_path = f"stt.{domain}"
entity._attr_name = f"{domain} STT entity"
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
@ -589,7 +520,7 @@ async def test_default_engine_prefer_cloud_entity(
async def test_get_engine_legacy(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider
) -> None:
"""Test async_get_speech_to_text_engine."""
mock_stt_platform(
@ -614,7 +545,7 @@ async def test_get_engine_legacy(
async def test_get_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None:
"""Test async_get_speech_to_text_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)