Deduplicate STT mocks (#124754)

This commit is contained in:
Erik Montnemery 2024-08-28 09:25:56 +02:00 committed by GitHub
parent f9bf7f7e05
commit 1add00a68d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 147 additions and 189 deletions

View File

@ -36,6 +36,7 @@ from tests.common import (
mock_integration, mock_integration,
mock_platform, mock_platform,
) )
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
_TRANSCRIPT = "test transcript" _TRANSCRIPT = "test transcript"
@ -47,67 +48,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir.""" """Mock the TTS cache dir with empty dir."""
class BaseProvider:
"""Mock STT provider."""
_supported_languages = ["en-US"]
def __init__(self, text: str) -> None:
"""Init test provider."""
self.text = text
self.received: list[bytes] = []
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages
@property
def supported_formats(self) -> list[stt.AudioFormats]:
"""Return a list of supported formats."""
return [stt.AudioFormats.WAV]
@property
def supported_codecs(self) -> list[stt.AudioCodecs]:
"""Return a list of supported codecs."""
return [stt.AudioCodecs.PCM]
@property
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
"""Return a list of supported bitrates."""
return [stt.AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
"""Return a list of supported samplerates."""
return [stt.AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[stt.AudioChannels]:
"""Return a list of supported channels."""
return [stt.AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
) -> stt.SpeechResult:
"""Process an audio stream."""
async for data in stream:
if not data:
break
self.received.append(data)
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSttProvider(BaseProvider, stt.Provider):
"""Mock provider."""
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
"""Mock provider entity."""
_attr_name = "Mock STT"
class MockTTSProvider(tts.Provider): class MockTTSProvider(tts.Provider):
"""Mock TTS provider.""" """Mock TTS provider."""
@ -166,15 +106,17 @@ async def mock_tts_provider() -> MockTTSProvider:
@pytest.fixture @pytest.fixture
async def mock_stt_provider() -> MockSttProvider: async def mock_stt_provider() -> MockSTTProvider:
"""Mock STT provider.""" """Mock STT provider."""
return MockSttProvider(_TRANSCRIPT) return MockSTTProvider(supported_languages=["en-US"], text=_TRANSCRIPT)
@pytest.fixture @pytest.fixture
def mock_stt_provider_entity() -> MockSttProviderEntity: def mock_stt_provider_entity() -> MockSTTProviderEntity:
"""Test provider entity fixture.""" """Test provider entity fixture."""
return MockSttProviderEntity(_TRANSCRIPT) entity = MockSTTProviderEntity(supported_languages=["en-US"], text=_TRANSCRIPT)
entity._attr_name = "Mock STT"
return entity
class MockSttPlatform(MockPlatform): class MockSttPlatform(MockPlatform):
@ -290,8 +232,8 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
@pytest.fixture @pytest.fixture
async def init_supporting_components( async def init_supporting_components(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSTTProviderEntity,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
mock_wake_word_provider_entity2: MockWakeWordEntity2, mock_wake_word_provider_entity2: MockWakeWordEntity2,

View File

@ -22,8 +22,8 @@ from homeassistant.setup import async_setup_component
from .conftest import ( from .conftest import (
BYTES_ONE_SECOND, BYTES_ONE_SECOND,
MockSttProvider, MockSTTProvider,
MockSttProviderEntity, MockSTTProviderEntity,
MockTTSProvider, MockTTSProvider,
MockWakeWordEntity, MockWakeWordEntity,
make_10ms_chunk, make_10ms_chunk,
@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
async def test_pipeline_from_audio_stream_auto( async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSTTProviderEntity,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -88,7 +88,7 @@ async def test_pipeline_from_audio_stream_auto(
async def test_pipeline_from_audio_stream_legacy( async def test_pipeline_from_audio_stream_legacy(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -153,7 +153,7 @@ async def test_pipeline_from_audio_stream_legacy(
async def test_pipeline_from_audio_stream_entity( async def test_pipeline_from_audio_stream_entity(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSTTProviderEntity,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -218,7 +218,7 @@ async def test_pipeline_from_audio_stream_entity(
async def test_pipeline_from_audio_stream_no_stt( async def test_pipeline_from_audio_stream_no_stt(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -281,7 +281,7 @@ async def test_pipeline_from_audio_stream_no_stt(
async def test_pipeline_from_audio_stream_unknown_pipeline( async def test_pipeline_from_audio_stream_unknown_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
async def test_pipeline_from_audio_stream_wake_word( async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSTTProviderEntity,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -395,7 +395,7 @@ async def test_pipeline_from_audio_stream_wake_word(
async def test_pipeline_save_audio( async def test_pipeline_save_audio(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components, init_supporting_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -474,7 +474,7 @@ async def test_pipeline_save_audio(
async def test_pipeline_saved_audio_with_device_id( async def test_pipeline_saved_audio_with_device_id(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components, init_supporting_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -529,7 +529,7 @@ async def test_pipeline_saved_audio_with_device_id(
async def test_pipeline_saved_audio_write_error( async def test_pipeline_saved_audio_write_error(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components, init_supporting_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -578,7 +578,7 @@ async def test_pipeline_saved_audio_write_error(
async def test_pipeline_saved_audio_empty_queue( async def test_pipeline_saved_audio_empty_queue(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_supporting_components, init_supporting_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -641,7 +641,7 @@ async def test_pipeline_saved_audio_empty_queue(
async def test_wake_word_detection_aborted( async def test_wake_word_detection_aborted(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSTTProvider,
mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity: MockWakeWordEntity,
init_components, init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,

View File

@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES from . import MANY_LANGUAGES
from .conftest import MockSttProviderEntity, MockTTSProvider from .conftest import MockSTTProviderEntity, MockTTSProvider
from tests.common import flush_store from tests.common import flush_store
@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts(
@pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline( async def test_default_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSTTProviderEntity,
mock_tts_provider: MockTTSProvider, mock_tts_provider: MockTTSProvider,
ha_language: str, ha_language: str,
ha_country: str | None, ha_country: str | None,
@ -441,7 +441,7 @@ async def test_default_pipeline(
@pytest.mark.usefixtures("init_supporting_components") @pytest.mark.usefixtures("init_supporting_components")
async def test_default_pipeline_unsupported_stt_language( async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
) -> None: ) -> None:
"""Test async_get_pipeline.""" """Test async_get_pipeline."""
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]): with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):

View File

@ -743,7 +743,7 @@ async def test_stt_stream_failed(
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
with patch( with patch(
"tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream", "tests.components.assist_pipeline.conftest.MockSTTProviderEntity.async_process_audio_stream",
side_effect=RuntimeError, side_effect=RuntimeError,
): ):
await client.send_json_auto_id( await client.send_json_auto_id(

View File

@ -2,11 +2,22 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine from collections.abc import AsyncIterable, Callable, Coroutine
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from homeassistant.components.stt import Provider from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -14,6 +25,80 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform, mock_platform from tests.common import MockPlatform, mock_platform
TEST_DOMAIN = "test"
class BaseProvider:
"""Mock STT provider."""
fail_process_audio = False
def __init__(
self, *, supported_languages: list[str] | None = None, text: str = "test_result"
) -> None:
"""Init test provider."""
self._supported_languages = supported_languages or ["de", "de-CH", "en"]
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
self.received: list[bytes] = []
self.text = text
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV, AudioFormats.OGG]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM, AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bitrates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported samplerates."""
return [AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))
async for data in stream:
if not data:
break
self.received.append(data)
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult(self.text, SpeechResultState.SUCCESS)
class MockSTTProvider(BaseProvider, Provider):
"""Mock provider."""
url_path = TEST_DOMAIN
class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
class MockSTTPlatform(MockPlatform): class MockSTTPlatform(MockPlatform):
"""Help to set up test stt service.""" """Help to set up test stt service."""

View File

@ -1,6 +1,6 @@
"""Test STT component setup.""" """Test STT component setup."""
from collections.abc import AsyncIterable, Generator, Iterable from collections.abc import Generator, Iterable
from contextlib import ExitStack from contextlib import ExitStack
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
@ -10,16 +10,6 @@ import pytest
from homeassistant.components.stt import ( from homeassistant.components.stt import (
DOMAIN, DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
async_default_engine, async_default_engine,
async_get_provider, async_get_provider,
async_get_speech_to_text_engine, async_get_speech_to_text_engine,
@ -29,7 +19,13 @@ from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import mock_stt_entity_platform, mock_stt_platform from .common import (
TEST_DOMAIN,
MockSTTProvider,
MockSTTProviderEntity,
mock_stt_entity_platform,
mock_stt_platform,
)
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
@ -41,82 +37,17 @@ from tests.common import (
) )
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
TEST_DOMAIN = "test"
class BaseProvider:
"""Mock provider."""
fail_process_audio = False
def __init__(self) -> None:
"""Init test provider."""
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["de", "de-CH", "en"]
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV, AudioFormats.OGG]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM, AudioCodecs.OPUS]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bitrates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported samplerates."""
return [AudioSampleRates.SAMPLERATE_16000]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream."""
self.calls.append((metadata, stream))
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult("test_result", SpeechResultState.SUCCESS)
class MockProvider(BaseProvider, Provider):
"""Mock provider."""
url_path = TEST_DOMAIN
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
@pytest.fixture @pytest.fixture
def mock_provider() -> MockProvider: def mock_provider() -> MockSTTProvider:
"""Test provider fixture.""" """Test provider fixture."""
return MockProvider() return MockSTTProvider()
@pytest.fixture @pytest.fixture
def mock_provider_entity() -> MockProviderEntity: def mock_provider_entity() -> MockSTTProviderEntity:
"""Test provider entity fixture.""" """Test provider entity fixture."""
return MockProviderEntity() return MockSTTProviderEntity()
class STTFlow(ConfigFlow): class STTFlow(ConfigFlow):
@ -148,14 +79,14 @@ async def setup_fixture(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
) -> MockProvider | MockProviderEntity: ) -> MockSTTProvider | MockSTTProviderEntity:
"""Set up the test environment.""" """Set up the test environment."""
provider: MockProvider | MockProviderEntity provider: MockSTTProvider | MockSTTProviderEntity
if request.param == "mock_setup": if request.param == "mock_setup":
provider = MockProvider() provider = MockSTTProvider()
await mock_setup(hass, tmp_path, provider) await mock_setup(hass, tmp_path, provider)
elif request.param == "mock_config_entry_setup": elif request.param == "mock_config_entry_setup":
provider = MockProviderEntity() provider = MockSTTProviderEntity()
await mock_config_entry_setup(hass, tmp_path, provider) await mock_config_entry_setup(hass, tmp_path, provider)
else: else:
raise RuntimeError("Invalid setup fixture") raise RuntimeError("Invalid setup fixture")
@ -166,7 +97,7 @@ async def setup_fixture(
async def mock_setup( async def mock_setup(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider: MockProvider, mock_provider: MockSTTProvider,
) -> None: ) -> None:
"""Set up a test provider.""" """Set up a test provider."""
mock_stt_platform( mock_stt_platform(
@ -182,7 +113,7 @@ async def mock_setup(
async def mock_config_entry_setup( async def mock_config_entry_setup(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider_entity: MockProviderEntity, mock_provider_entity: MockSTTProviderEntity,
test_domain: str = TEST_DOMAIN, test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Set up a test provider via config entry.""" """Set up a test provider via config entry."""
@ -234,7 +165,7 @@ async def mock_config_entry_setup(
async def test_get_provider_info( async def test_get_provider_info(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity, setup: MockSTTProvider | MockSTTProviderEntity,
) -> None: ) -> None:
"""Test engine that doesn't exist.""" """Test engine that doesn't exist."""
client = await hass_client() client = await hass_client()
@ -256,7 +187,7 @@ async def test_get_provider_info(
async def test_non_existing_provider( async def test_non_existing_provider(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity, setup: MockSTTProvider | MockSTTProviderEntity,
) -> None: ) -> None:
"""Test streaming to engine that doesn't exist.""" """Test streaming to engine that doesn't exist."""
client = await hass_client() client = await hass_client()
@ -282,7 +213,7 @@ async def test_non_existing_provider(
async def test_stream_audio( async def test_stream_audio(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: MockProvider | MockProviderEntity, setup: MockSTTProvider | MockSTTProviderEntity,
) -> None: ) -> None:
"""Test streaming audio and getting response.""" """Test streaming audio and getting response."""
client = await hass_client() client = await hass_client()
@ -343,7 +274,7 @@ async def test_metadata_errors(
header: str | None, header: str | None,
status: int, status: int,
error: str, error: str,
setup: MockProvider | MockProviderEntity, setup: MockSTTProvider | MockSTTProviderEntity,
) -> None: ) -> None:
"""Test metadata errors.""" """Test metadata errors."""
client = await hass_client() client = await hass_client()
@ -359,7 +290,7 @@ async def test_metadata_errors(
async def test_get_provider( async def test_get_provider(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider: MockProvider, mock_provider: MockSTTProvider,
) -> None: ) -> None:
"""Test we can get STT providers.""" """Test we can get STT providers."""
await mock_setup(hass, tmp_path, mock_provider) await mock_setup(hass, tmp_path, mock_provider)
@ -370,7 +301,7 @@ async def test_get_provider(
async def test_config_entry_unload( async def test_config_entry_unload(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None: ) -> None:
"""Test we can unload config entry.""" """Test we can unload config entry."""
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
@ -382,7 +313,7 @@ async def test_config_entry_unload(
async def test_restore_state( async def test_restore_state(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider_entity: MockProviderEntity, mock_provider_entity: MockSTTProviderEntity,
) -> None: ) -> None:
"""Test we restore state in the integration.""" """Test we restore state in the integration."""
entity_id = f"{DOMAIN}.{TEST_DOMAIN}" entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
@ -409,7 +340,7 @@ async def test_restore_state(
async def test_ws_list_engines( async def test_ws_list_engines(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
setup: MockProvider | MockProviderEntity, setup: MockSTTProvider | MockSTTProviderEntity,
engine_id: str, engine_id: str,
extra_data: dict[str, str], extra_data: dict[str, str],
) -> None: ) -> None:
@ -491,7 +422,7 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
async def test_default_engine( async def test_default_engine(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider: MockProvider, mock_provider: MockSTTProvider,
) -> None: ) -> None:
"""Test async_default_engine.""" """Test async_default_engine."""
mock_stt_platform( mock_stt_platform(
@ -507,7 +438,7 @@ async def test_default_engine(
async def test_default_engine_entity( async def test_default_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None: ) -> None:
"""Test async_default_engine.""" """Test async_default_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
@ -519,8 +450,8 @@ async def test_default_engine_entity(
async def test_default_engine_prefer_entity( async def test_default_engine_prefer_entity(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider_entity: MockProviderEntity, mock_provider_entity: MockSTTProviderEntity,
mock_provider: MockProvider, mock_provider: MockSTTProvider,
config_flow_test_domains: str, config_flow_test_domains: str,
) -> None: ) -> None:
"""Test async_default_engine. """Test async_default_engine.
@ -558,7 +489,7 @@ async def test_default_engine_prefer_entity(
async def test_default_engine_prefer_cloud_entity( async def test_default_engine_prefer_cloud_entity(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
mock_provider: MockProvider, mock_provider: MockSTTProvider,
config_flow_test_domains: str, config_flow_test_domains: str,
) -> None: ) -> None:
"""Test async_default_engine. """Test async_default_engine.
@ -569,7 +500,7 @@ async def test_default_engine_prefer_cloud_entity(
""" """
await mock_setup(hass, tmp_path, mock_provider) await mock_setup(hass, tmp_path, mock_provider)
for domain in config_flow_test_domains: for domain in config_flow_test_domains:
entity = MockProviderEntity() entity = MockSTTProviderEntity()
entity.url_path = f"stt.{domain}" entity.url_path = f"stt.{domain}"
entity._attr_name = f"{domain} STT entity" entity._attr_name = f"{domain} STT entity"
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain) await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
@ -589,7 +520,7 @@ async def test_default_engine_prefer_cloud_entity(
async def test_get_engine_legacy( async def test_get_engine_legacy(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider
) -> None: ) -> None:
"""Test async_get_speech_to_text_engine.""" """Test async_get_speech_to_text_engine."""
mock_stt_platform( mock_stt_platform(
@ -614,7 +545,7 @@ async def test_get_engine_legacy(
async def test_get_engine_entity( async def test_get_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None: ) -> None:
"""Test async_get_speech_to_text_engine.""" """Test async_get_speech_to_text_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)