mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Deduplicate STT mocks (#124754)
This commit is contained in:
parent
f9bf7f7e05
commit
1add00a68d
@ -36,6 +36,7 @@ from tests.common import (
|
|||||||
mock_integration,
|
mock_integration,
|
||||||
mock_platform,
|
mock_platform,
|
||||||
)
|
)
|
||||||
|
from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity
|
||||||
|
|
||||||
_TRANSCRIPT = "test transcript"
|
_TRANSCRIPT = "test transcript"
|
||||||
|
|
||||||
@ -47,67 +48,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
|||||||
"""Mock the TTS cache dir with empty dir."""
|
"""Mock the TTS cache dir with empty dir."""
|
||||||
|
|
||||||
|
|
||||||
class BaseProvider:
|
|
||||||
"""Mock STT provider."""
|
|
||||||
|
|
||||||
_supported_languages = ["en-US"]
|
|
||||||
|
|
||||||
def __init__(self, text: str) -> None:
|
|
||||||
"""Init test provider."""
|
|
||||||
self.text = text
|
|
||||||
self.received: list[bytes] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
return self._supported_languages
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
|
||||||
"""Return a list of supported formats."""
|
|
||||||
return [stt.AudioFormats.WAV]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_codecs(self) -> list[stt.AudioCodecs]:
|
|
||||||
"""Return a list of supported codecs."""
|
|
||||||
return [stt.AudioCodecs.PCM]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_bit_rates(self) -> list[stt.AudioBitRates]:
|
|
||||||
"""Return a list of supported bitrates."""
|
|
||||||
return [stt.AudioBitRates.BITRATE_16]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
|
|
||||||
"""Return a list of supported samplerates."""
|
|
||||||
return [stt.AudioSampleRates.SAMPLERATE_16000]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_channels(self) -> list[stt.AudioChannels]:
|
|
||||||
"""Return a list of supported channels."""
|
|
||||||
return [stt.AudioChannels.CHANNEL_MONO]
|
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
|
||||||
self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
|
|
||||||
) -> stt.SpeechResult:
|
|
||||||
"""Process an audio stream."""
|
|
||||||
async for data in stream:
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
self.received.append(data)
|
|
||||||
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
|
|
||||||
|
|
||||||
|
|
||||||
class MockSttProvider(BaseProvider, stt.Provider):
|
|
||||||
"""Mock provider."""
|
|
||||||
|
|
||||||
|
|
||||||
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
|
|
||||||
"""Mock provider entity."""
|
|
||||||
|
|
||||||
_attr_name = "Mock STT"
|
|
||||||
|
|
||||||
|
|
||||||
class MockTTSProvider(tts.Provider):
|
class MockTTSProvider(tts.Provider):
|
||||||
"""Mock TTS provider."""
|
"""Mock TTS provider."""
|
||||||
|
|
||||||
@ -166,15 +106,17 @@ async def mock_tts_provider() -> MockTTSProvider:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_stt_provider() -> MockSttProvider:
|
async def mock_stt_provider() -> MockSTTProvider:
|
||||||
"""Mock STT provider."""
|
"""Mock STT provider."""
|
||||||
return MockSttProvider(_TRANSCRIPT)
|
return MockSTTProvider(supported_languages=["en-US"], text=_TRANSCRIPT)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_stt_provider_entity() -> MockSttProviderEntity:
|
def mock_stt_provider_entity() -> MockSTTProviderEntity:
|
||||||
"""Test provider entity fixture."""
|
"""Test provider entity fixture."""
|
||||||
return MockSttProviderEntity(_TRANSCRIPT)
|
entity = MockSTTProviderEntity(supported_languages=["en-US"], text=_TRANSCRIPT)
|
||||||
|
entity._attr_name = "Mock STT"
|
||||||
|
return entity
|
||||||
|
|
||||||
|
|
||||||
class MockSttPlatform(MockPlatform):
|
class MockSttPlatform(MockPlatform):
|
||||||
@ -290,8 +232,8 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def init_supporting_components(
|
async def init_supporting_components(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_provider: MockTTSProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
mock_wake_word_provider_entity2: MockWakeWordEntity2,
|
||||||
|
@ -22,8 +22,8 @@ from homeassistant.setup import async_setup_component
|
|||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
BYTES_ONE_SECOND,
|
BYTES_ONE_SECOND,
|
||||||
MockSttProvider,
|
MockSTTProvider,
|
||||||
MockSttProviderEntity,
|
MockSTTProviderEntity,
|
||||||
MockTTSProvider,
|
MockTTSProvider,
|
||||||
MockWakeWordEntity,
|
MockWakeWordEntity,
|
||||||
make_10ms_chunk,
|
make_10ms_chunk,
|
||||||
@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
|||||||
|
|
||||||
async def test_pipeline_from_audio_stream_auto(
|
async def test_pipeline_from_audio_stream_auto(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -88,7 +88,7 @@ async def test_pipeline_from_audio_stream_auto(
|
|||||||
async def test_pipeline_from_audio_stream_legacy(
|
async def test_pipeline_from_audio_stream_legacy(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -153,7 +153,7 @@ async def test_pipeline_from_audio_stream_legacy(
|
|||||||
async def test_pipeline_from_audio_stream_entity(
|
async def test_pipeline_from_audio_stream_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -218,7 +218,7 @@ async def test_pipeline_from_audio_stream_entity(
|
|||||||
async def test_pipeline_from_audio_stream_no_stt(
|
async def test_pipeline_from_audio_stream_no_stt(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -281,7 +281,7 @@ async def test_pipeline_from_audio_stream_no_stt(
|
|||||||
async def test_pipeline_from_audio_stream_unknown_pipeline(
|
async def test_pipeline_from_audio_stream_unknown_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
|
|||||||
|
|
||||||
async def test_pipeline_from_audio_stream_wake_word(
|
async def test_pipeline_from_audio_stream_wake_word(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
@ -395,7 +395,7 @@ async def test_pipeline_from_audio_stream_wake_word(
|
|||||||
|
|
||||||
async def test_pipeline_save_audio(
|
async def test_pipeline_save_audio(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_supporting_components,
|
init_supporting_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
@ -474,7 +474,7 @@ async def test_pipeline_save_audio(
|
|||||||
|
|
||||||
async def test_pipeline_saved_audio_with_device_id(
|
async def test_pipeline_saved_audio_with_device_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_supporting_components,
|
init_supporting_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
@ -529,7 +529,7 @@ async def test_pipeline_saved_audio_with_device_id(
|
|||||||
|
|
||||||
async def test_pipeline_saved_audio_write_error(
|
async def test_pipeline_saved_audio_write_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_supporting_components,
|
init_supporting_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
@ -578,7 +578,7 @@ async def test_pipeline_saved_audio_write_error(
|
|||||||
|
|
||||||
async def test_pipeline_saved_audio_empty_queue(
|
async def test_pipeline_saved_audio_empty_queue(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_supporting_components,
|
init_supporting_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
@ -641,7 +641,7 @@ async def test_pipeline_saved_audio_empty_queue(
|
|||||||
|
|
||||||
async def test_wake_word_detection_aborted(
|
async def test_wake_word_detection_aborted(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSTTProvider,
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
init_components,
|
init_components,
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import MANY_LANGUAGES
|
from . import MANY_LANGUAGES
|
||||||
from .conftest import MockSttProviderEntity, MockTTSProvider
|
from .conftest import MockSTTProviderEntity, MockTTSProvider
|
||||||
|
|
||||||
from tests.common import flush_store
|
from tests.common import flush_store
|
||||||
|
|
||||||
@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts(
|
|||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
async def test_default_pipeline(
|
async def test_default_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
mock_tts_provider: MockTTSProvider,
|
mock_tts_provider: MockTTSProvider,
|
||||||
ha_language: str,
|
ha_language: str,
|
||||||
ha_country: str | None,
|
ha_country: str | None,
|
||||||
@ -441,7 +441,7 @@ async def test_default_pipeline(
|
|||||||
|
|
||||||
@pytest.mark.usefixtures("init_supporting_components")
|
@pytest.mark.usefixtures("init_supporting_components")
|
||||||
async def test_default_pipeline_unsupported_stt_language(
|
async def test_default_pipeline_unsupported_stt_language(
|
||||||
hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity
|
hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_get_pipeline."""
|
"""Test async_get_pipeline."""
|
||||||
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):
|
with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]):
|
||||||
|
@ -743,7 +743,7 @@ async def test_stt_stream_failed(
|
|||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream",
|
"tests.components.assist_pipeline.conftest.MockSTTProviderEntity.async_process_audio_stream",
|
||||||
side_effect=RuntimeError,
|
side_effect=RuntimeError,
|
||||||
):
|
):
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
|
@ -2,11 +2,22 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import AsyncIterable, Callable, Coroutine
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.stt import Provider
|
from homeassistant.components.stt import (
|
||||||
|
AudioBitRates,
|
||||||
|
AudioChannels,
|
||||||
|
AudioCodecs,
|
||||||
|
AudioFormats,
|
||||||
|
AudioSampleRates,
|
||||||
|
Provider,
|
||||||
|
SpeechMetadata,
|
||||||
|
SpeechResult,
|
||||||
|
SpeechResultState,
|
||||||
|
SpeechToTextEntity,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
@ -14,6 +25,80 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|||||||
|
|
||||||
from tests.common import MockPlatform, mock_platform
|
from tests.common import MockPlatform, mock_platform
|
||||||
|
|
||||||
|
TEST_DOMAIN = "test"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProvider:
|
||||||
|
"""Mock STT provider."""
|
||||||
|
|
||||||
|
fail_process_audio = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *, supported_languages: list[str] | None = None, text: str = "test_result"
|
||||||
|
) -> None:
|
||||||
|
"""Init test provider."""
|
||||||
|
self._supported_languages = supported_languages or ["de", "de-CH", "en"]
|
||||||
|
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
|
||||||
|
self.received: list[bytes] = []
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return self._supported_languages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_formats(self) -> list[AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
return [AudioFormats.WAV, AudioFormats.OGG]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_codecs(self) -> list[AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
return [AudioCodecs.PCM, AudioCodecs.OPUS]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||||
|
"""Return a list of supported bitrates."""
|
||||||
|
return [AudioBitRates.BITRATE_16]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||||
|
"""Return a list of supported samplerates."""
|
||||||
|
return [AudioSampleRates.SAMPLERATE_16000]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_channels(self) -> list[AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
return [AudioChannels.CHANNEL_MONO]
|
||||||
|
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream."""
|
||||||
|
self.calls.append((metadata, stream))
|
||||||
|
async for data in stream:
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.received.append(data)
|
||||||
|
if self.fail_process_audio:
|
||||||
|
return SpeechResult(None, SpeechResultState.ERROR)
|
||||||
|
|
||||||
|
return SpeechResult(self.text, SpeechResultState.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
|
class MockSTTProvider(BaseProvider, Provider):
|
||||||
|
"""Mock provider."""
|
||||||
|
|
||||||
|
url_path = TEST_DOMAIN
|
||||||
|
|
||||||
|
|
||||||
|
class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity):
|
||||||
|
"""Mock provider entity."""
|
||||||
|
|
||||||
|
url_path = "stt.test"
|
||||||
|
_attr_name = "test"
|
||||||
|
|
||||||
|
|
||||||
class MockSTTPlatform(MockPlatform):
|
class MockSTTPlatform(MockPlatform):
|
||||||
"""Help to set up test stt service."""
|
"""Help to set up test stt service."""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
|
|
||||||
from collections.abc import AsyncIterable, Generator, Iterable
|
from collections.abc import Generator, Iterable
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -10,16 +10,6 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant.components.stt import (
|
from homeassistant.components.stt import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
AudioBitRates,
|
|
||||||
AudioChannels,
|
|
||||||
AudioCodecs,
|
|
||||||
AudioFormats,
|
|
||||||
AudioSampleRates,
|
|
||||||
Provider,
|
|
||||||
SpeechMetadata,
|
|
||||||
SpeechResult,
|
|
||||||
SpeechResultState,
|
|
||||||
SpeechToTextEntity,
|
|
||||||
async_default_engine,
|
async_default_engine,
|
||||||
async_get_provider,
|
async_get_provider,
|
||||||
async_get_speech_to_text_engine,
|
async_get_speech_to_text_engine,
|
||||||
@ -29,7 +19,13 @@ from homeassistant.core import HomeAssistant, State
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .common import mock_stt_entity_platform, mock_stt_platform
|
from .common import (
|
||||||
|
TEST_DOMAIN,
|
||||||
|
MockSTTProvider,
|
||||||
|
MockSTTProviderEntity,
|
||||||
|
mock_stt_entity_platform,
|
||||||
|
mock_stt_platform,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
MockConfigEntry,
|
MockConfigEntry,
|
||||||
@ -41,82 +37,17 @@ from tests.common import (
|
|||||||
)
|
)
|
||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
TEST_DOMAIN = "test"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseProvider:
|
|
||||||
"""Mock provider."""
|
|
||||||
|
|
||||||
fail_process_audio = False
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Init test provider."""
|
|
||||||
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
return ["de", "de-CH", "en"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_formats(self) -> list[AudioFormats]:
|
|
||||||
"""Return a list of supported formats."""
|
|
||||||
return [AudioFormats.WAV, AudioFormats.OGG]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_codecs(self) -> list[AudioCodecs]:
|
|
||||||
"""Return a list of supported codecs."""
|
|
||||||
return [AudioCodecs.PCM, AudioCodecs.OPUS]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
|
||||||
"""Return a list of supported bitrates."""
|
|
||||||
return [AudioBitRates.BITRATE_16]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
|
||||||
"""Return a list of supported samplerates."""
|
|
||||||
return [AudioSampleRates.SAMPLERATE_16000]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_channels(self) -> list[AudioChannels]:
|
|
||||||
"""Return a list of supported channels."""
|
|
||||||
return [AudioChannels.CHANNEL_MONO]
|
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
|
||||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
|
||||||
) -> SpeechResult:
|
|
||||||
"""Process an audio stream."""
|
|
||||||
self.calls.append((metadata, stream))
|
|
||||||
if self.fail_process_audio:
|
|
||||||
return SpeechResult(None, SpeechResultState.ERROR)
|
|
||||||
|
|
||||||
return SpeechResult("test_result", SpeechResultState.SUCCESS)
|
|
||||||
|
|
||||||
|
|
||||||
class MockProvider(BaseProvider, Provider):
|
|
||||||
"""Mock provider."""
|
|
||||||
|
|
||||||
url_path = TEST_DOMAIN
|
|
||||||
|
|
||||||
|
|
||||||
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
|
||||||
"""Mock provider entity."""
|
|
||||||
|
|
||||||
url_path = "stt.test"
|
|
||||||
_attr_name = "test"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_provider() -> MockProvider:
|
def mock_provider() -> MockSTTProvider:
|
||||||
"""Test provider fixture."""
|
"""Test provider fixture."""
|
||||||
return MockProvider()
|
return MockSTTProvider()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_provider_entity() -> MockProviderEntity:
|
def mock_provider_entity() -> MockSTTProviderEntity:
|
||||||
"""Test provider entity fixture."""
|
"""Test provider entity fixture."""
|
||||||
return MockProviderEntity()
|
return MockSTTProviderEntity()
|
||||||
|
|
||||||
|
|
||||||
class STTFlow(ConfigFlow):
|
class STTFlow(ConfigFlow):
|
||||||
@ -148,14 +79,14 @@ async def setup_fixture(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
) -> MockProvider | MockProviderEntity:
|
) -> MockSTTProvider | MockSTTProviderEntity:
|
||||||
"""Set up the test environment."""
|
"""Set up the test environment."""
|
||||||
provider: MockProvider | MockProviderEntity
|
provider: MockSTTProvider | MockSTTProviderEntity
|
||||||
if request.param == "mock_setup":
|
if request.param == "mock_setup":
|
||||||
provider = MockProvider()
|
provider = MockSTTProvider()
|
||||||
await mock_setup(hass, tmp_path, provider)
|
await mock_setup(hass, tmp_path, provider)
|
||||||
elif request.param == "mock_config_entry_setup":
|
elif request.param == "mock_config_entry_setup":
|
||||||
provider = MockProviderEntity()
|
provider = MockSTTProviderEntity()
|
||||||
await mock_config_entry_setup(hass, tmp_path, provider)
|
await mock_config_entry_setup(hass, tmp_path, provider)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Invalid setup fixture")
|
raise RuntimeError("Invalid setup fixture")
|
||||||
@ -166,7 +97,7 @@ async def setup_fixture(
|
|||||||
async def mock_setup(
|
async def mock_setup(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockSTTProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up a test provider."""
|
"""Set up a test provider."""
|
||||||
mock_stt_platform(
|
mock_stt_platform(
|
||||||
@ -182,7 +113,7 @@ async def mock_setup(
|
|||||||
async def mock_config_entry_setup(
|
async def mock_config_entry_setup(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider_entity: MockProviderEntity,
|
mock_provider_entity: MockSTTProviderEntity,
|
||||||
test_domain: str = TEST_DOMAIN,
|
test_domain: str = TEST_DOMAIN,
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Set up a test provider via config entry."""
|
"""Set up a test provider via config entry."""
|
||||||
@ -234,7 +165,7 @@ async def mock_config_entry_setup(
|
|||||||
async def test_get_provider_info(
|
async def test_get_provider_info(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: MockProvider | MockProviderEntity,
|
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test engine that doesn't exist."""
|
"""Test engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -256,7 +187,7 @@ async def test_get_provider_info(
|
|||||||
async def test_non_existing_provider(
|
async def test_non_existing_provider(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: MockProvider | MockProviderEntity,
|
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming to engine that doesn't exist."""
|
"""Test streaming to engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -282,7 +213,7 @@ async def test_non_existing_provider(
|
|||||||
async def test_stream_audio(
|
async def test_stream_audio(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: MockProvider | MockProviderEntity,
|
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and getting response."""
|
"""Test streaming audio and getting response."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -343,7 +274,7 @@ async def test_metadata_errors(
|
|||||||
header: str | None,
|
header: str | None,
|
||||||
status: int,
|
status: int,
|
||||||
error: str,
|
error: str,
|
||||||
setup: MockProvider | MockProviderEntity,
|
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test metadata errors."""
|
"""Test metadata errors."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -359,7 +290,7 @@ async def test_metadata_errors(
|
|||||||
async def test_get_provider(
|
async def test_get_provider(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockSTTProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can get STT providers."""
|
"""Test we can get STT providers."""
|
||||||
await mock_setup(hass, tmp_path, mock_provider)
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
@ -370,7 +301,7 @@ async def test_get_provider(
|
|||||||
|
|
||||||
|
|
||||||
async def test_config_entry_unload(
|
async def test_config_entry_unload(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can unload config entry."""
|
"""Test we can unload config entry."""
|
||||||
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
@ -382,7 +313,7 @@ async def test_config_entry_unload(
|
|||||||
async def test_restore_state(
|
async def test_restore_state(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider_entity: MockProviderEntity,
|
mock_provider_entity: MockSTTProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test we restore state in the integration."""
|
"""Test we restore state in the integration."""
|
||||||
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
|
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
|
||||||
@ -409,7 +340,7 @@ async def test_restore_state(
|
|||||||
async def test_ws_list_engines(
|
async def test_ws_list_engines(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
setup: MockProvider | MockProviderEntity,
|
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||||
engine_id: str,
|
engine_id: str,
|
||||||
extra_data: dict[str, str],
|
extra_data: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -491,7 +422,7 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
|||||||
async def test_default_engine(
|
async def test_default_engine(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockSTTProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine."""
|
"""Test async_default_engine."""
|
||||||
mock_stt_platform(
|
mock_stt_platform(
|
||||||
@ -507,7 +438,7 @@ async def test_default_engine(
|
|||||||
|
|
||||||
|
|
||||||
async def test_default_engine_entity(
|
async def test_default_engine_entity(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine."""
|
"""Test async_default_engine."""
|
||||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
@ -519,8 +450,8 @@ async def test_default_engine_entity(
|
|||||||
async def test_default_engine_prefer_entity(
|
async def test_default_engine_prefer_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider_entity: MockProviderEntity,
|
mock_provider_entity: MockSTTProviderEntity,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockSTTProvider,
|
||||||
config_flow_test_domains: str,
|
config_flow_test_domains: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine.
|
"""Test async_default_engine.
|
||||||
@ -558,7 +489,7 @@ async def test_default_engine_prefer_entity(
|
|||||||
async def test_default_engine_prefer_cloud_entity(
|
async def test_default_engine_prefer_cloud_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
mock_provider: MockProvider,
|
mock_provider: MockSTTProvider,
|
||||||
config_flow_test_domains: str,
|
config_flow_test_domains: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_default_engine.
|
"""Test async_default_engine.
|
||||||
@ -569,7 +500,7 @@ async def test_default_engine_prefer_cloud_entity(
|
|||||||
"""
|
"""
|
||||||
await mock_setup(hass, tmp_path, mock_provider)
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
for domain in config_flow_test_domains:
|
for domain in config_flow_test_domains:
|
||||||
entity = MockProviderEntity()
|
entity = MockSTTProviderEntity()
|
||||||
entity.url_path = f"stt.{domain}"
|
entity.url_path = f"stt.{domain}"
|
||||||
entity._attr_name = f"{domain} STT entity"
|
entity._attr_name = f"{domain} STT entity"
|
||||||
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
|
await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
|
||||||
@ -589,7 +520,7 @@ async def test_default_engine_prefer_cloud_entity(
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_engine_legacy(
|
async def test_get_engine_legacy(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_get_speech_to_text_engine."""
|
"""Test async_get_speech_to_text_engine."""
|
||||||
mock_stt_platform(
|
mock_stt_platform(
|
||||||
@ -614,7 +545,7 @@ async def test_get_engine_legacy(
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_engine_entity(
|
async def test_get_engine_entity(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test async_get_speech_to_text_engine."""
|
"""Test async_get_speech_to_text_engine."""
|
||||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user