diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index b7bf83a7ed0..c03874c16af 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -36,6 +36,7 @@ from tests.common import ( mock_integration, mock_platform, ) +from tests.components.stt.common import MockSTTProvider, MockSTTProviderEntity _TRANSCRIPT = "test transcript" @@ -47,67 +48,6 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None: """Mock the TTS cache dir with empty dir.""" -class BaseProvider: - """Mock STT provider.""" - - _supported_languages = ["en-US"] - - def __init__(self, text: str) -> None: - """Init test provider.""" - self.text = text - self.received: list[bytes] = [] - - @property - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - return self._supported_languages - - @property - def supported_formats(self) -> list[stt.AudioFormats]: - """Return a list of supported formats.""" - return [stt.AudioFormats.WAV] - - @property - def supported_codecs(self) -> list[stt.AudioCodecs]: - """Return a list of supported codecs.""" - return [stt.AudioCodecs.PCM] - - @property - def supported_bit_rates(self) -> list[stt.AudioBitRates]: - """Return a list of supported bitrates.""" - return [stt.AudioBitRates.BITRATE_16] - - @property - def supported_sample_rates(self) -> list[stt.AudioSampleRates]: - """Return a list of supported samplerates.""" - return [stt.AudioSampleRates.SAMPLERATE_16000] - - @property - def supported_channels(self) -> list[stt.AudioChannels]: - """Return a list of supported channels.""" - return [stt.AudioChannels.CHANNEL_MONO] - - async def async_process_audio_stream( - self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes] - ) -> stt.SpeechResult: - """Process an audio stream.""" - async for data in stream: - if not data: - break - self.received.append(data) - return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS) - - -class MockSttProvider(BaseProvider, stt.Provider): - """Mock provider.""" - - -class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity): - """Mock provider entity.""" - - _attr_name = "Mock STT" - - class MockTTSProvider(tts.Provider): """Mock TTS provider.""" @@ -166,15 +106,17 @@ async def mock_tts_provider() -> MockTTSProvider: @pytest.fixture -async def mock_stt_provider() -> MockSttProvider: +async def mock_stt_provider() -> MockSTTProvider: """Mock STT provider.""" - return MockSttProvider(_TRANSCRIPT) + return MockSTTProvider(supported_languages=["en-US"], text=_TRANSCRIPT) @pytest.fixture -def mock_stt_provider_entity() -> MockSttProviderEntity: +def mock_stt_provider_entity() -> MockSTTProviderEntity: """Test provider entity fixture.""" - return MockSttProviderEntity(_TRANSCRIPT) + entity = MockSTTProviderEntity(supported_languages=["en-US"], text=_TRANSCRIPT) + entity._attr_name = "Mock STT" + return entity class MockSttPlatform(MockPlatform): @@ -290,8 +232,8 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None]: @pytest.fixture async def init_supporting_components( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, - mock_stt_provider_entity: MockSttProviderEntity, + mock_stt_provider: MockSTTProvider, + mock_stt_provider_entity: MockSTTProviderEntity, mock_tts_provider: MockTTSProvider, mock_wake_word_provider_entity: MockWakeWordEntity, mock_wake_word_provider_entity2: MockWakeWordEntity2, diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 04edab7131f..31cc1268098 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -22,8 +22,8 @@ from homeassistant.setup import async_setup_component from .conftest import ( BYTES_ONE_SECOND, - MockSttProvider, - MockSttProviderEntity, + MockSTTProvider, + MockSTTProviderEntity, MockTTSProvider, MockWakeWordEntity, make_10ms_chunk, @@ -47,7 +47,7 @@ def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: async def test_pipeline_from_audio_stream_auto( hass: HomeAssistant, - mock_stt_provider_entity: MockSttProviderEntity, + mock_stt_provider_entity: MockSTTProviderEntity, init_components, snapshot: SnapshotAssertion, ) -> None: @@ -88,7 +88,7 @@ async def test_pipeline_from_audio_stream_auto( async def test_pipeline_from_audio_stream_legacy( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: @@ -153,7 +153,7 @@ async def test_pipeline_from_audio_stream_legacy( async def test_pipeline_from_audio_stream_entity( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - mock_stt_provider_entity: MockSttProviderEntity, + mock_stt_provider_entity: MockSTTProviderEntity, init_components, snapshot: SnapshotAssertion, ) -> None: @@ -218,7 +218,7 @@ async def test_pipeline_from_audio_stream_entity( async def test_pipeline_from_audio_stream_no_stt( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: @@ -281,7 +281,7 @@ async def test_pipeline_from_audio_stream_no_stt( async def test_pipeline_from_audio_stream_unknown_pipeline( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, init_components, snapshot: SnapshotAssertion, ) -> None: @@ -319,7 +319,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline( async def test_pipeline_from_audio_stream_wake_word( hass: HomeAssistant, - mock_stt_provider_entity: MockSttProviderEntity, + mock_stt_provider_entity: MockSTTProviderEntity, mock_wake_word_provider_entity: MockWakeWordEntity, init_components, snapshot: SnapshotAssertion, @@ -395,7 +395,7 @@ async def test_pipeline_from_audio_stream_wake_word( async def test_pipeline_save_audio( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, @@ -474,7 +474,7 @@ async def test_pipeline_save_audio( async def test_pipeline_saved_audio_with_device_id( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, @@ -529,7 +529,7 @@ async def test_pipeline_saved_audio_with_device_id( async def test_pipeline_saved_audio_write_error( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, @@ -578,7 +578,7 @@ async def test_pipeline_saved_audio_write_error( async def test_pipeline_saved_audio_empty_queue( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_supporting_components, snapshot: SnapshotAssertion, @@ -641,7 +641,7 @@ async def test_pipeline_saved_audio_empty_queue( async def test_wake_word_detection_aborted( hass: HomeAssistant, - mock_stt_provider: MockSttProvider, + mock_stt_provider: MockSTTProvider, mock_wake_word_provider_entity: MockWakeWordEntity, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index ef5d5edff9e..50d0fc9bed8 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -26,7 +26,7 @@ from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from . import MANY_LANGUAGES -from .conftest import MockSttProviderEntity, MockTTSProvider +from .conftest import MockSTTProviderEntity, MockTTSProvider from tests.common import flush_store @@ -398,7 +398,7 @@ async def test_default_pipeline_no_stt_tts( @pytest.mark.usefixtures("init_supporting_components") async def test_default_pipeline( hass: HomeAssistant, - mock_stt_provider_entity: MockSttProviderEntity, + mock_stt_provider_entity: MockSTTProviderEntity, mock_tts_provider: MockTTSProvider, ha_language: str, ha_country: str | None, @@ -441,7 +441,7 @@ async def test_default_pipeline( @pytest.mark.usefixtures("init_supporting_components") async def test_default_pipeline_unsupported_stt_language( - hass: HomeAssistant, mock_stt_provider_entity: MockSttProviderEntity + hass: HomeAssistant, mock_stt_provider_entity: MockSTTProviderEntity ) -> None: """Test async_get_pipeline.""" with patch.object(mock_stt_provider_entity, "_supported_languages", ["smurfish"]): diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index f1f68d4a423..e339ee74fbb 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -743,7 +743,7 @@ async def test_stt_stream_failed( client = await hass_ws_client(hass) with patch( - "tests.components.assist_pipeline.conftest.MockSttProviderEntity.async_process_audio_stream", + "tests.components.assist_pipeline.conftest.MockSTTProviderEntity.async_process_audio_stream", side_effect=RuntimeError, ): await client.send_json_auto_id( diff --git a/tests/components/stt/common.py b/tests/components/stt/common.py index e6c36c5b350..f964fca6b67 100644 --- a/tests/components/stt/common.py +++ b/tests/components/stt/common.py @@ -2,11 +2,22 @@ from __future__ import annotations -from collections.abc import Callable, Coroutine +from collections.abc import AsyncIterable, Callable, Coroutine from pathlib import Path from typing import Any -from homeassistant.components.stt import Provider +from homeassistant.components.stt import ( + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + Provider, + SpeechMetadata, + SpeechResult, + SpeechResultState, + SpeechToTextEntity, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -14,6 +25,80 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from tests.common import MockPlatform, mock_platform +TEST_DOMAIN = "test" + + +class BaseProvider: + """Mock STT provider.""" + + fail_process_audio = False + + def __init__( + self, *, supported_languages: list[str] | None = None, text: str = "test_result" + ) -> None: + """Init test provider.""" + self._supported_languages = supported_languages or ["de", "de-CH", "en"] + self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = [] + self.received: list[bytes] = [] + self.text = text + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return self._supported_languages + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV, AudioFormats.OGG] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM, AudioCodecs.OPUS] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bitrates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported samplerates.""" + return [AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream.""" + self.calls.append((metadata, stream)) + async for data in stream: + if not data: + break + self.received.append(data) + if self.fail_process_audio: + return SpeechResult(None, SpeechResultState.ERROR) + + return SpeechResult(self.text, SpeechResultState.SUCCESS) + + +class MockSTTProvider(BaseProvider, Provider): + """Mock provider.""" + + url_path = TEST_DOMAIN + + +class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity): + """Mock provider entity.""" + + url_path = "stt.test" + _attr_name = "test" + class MockSTTPlatform(MockPlatform): """Help to set up test stt service.""" diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 5c98b0f8d57..92225123995 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,6 +1,6 @@ """Test STT component setup.""" -from collections.abc import AsyncIterable, Generator, Iterable +from collections.abc import Generator, Iterable from contextlib import ExitStack from http import HTTPStatus from pathlib import Path @@ -10,16 +10,6 @@ import pytest from homeassistant.components.stt import ( DOMAIN, - AudioBitRates, - AudioChannels, - AudioCodecs, - AudioFormats, - AudioSampleRates, - Provider, - SpeechMetadata, - SpeechResult, - SpeechResultState, - SpeechToTextEntity, async_default_engine, async_get_provider, async_get_speech_to_text_engine, @@ -29,7 +19,13 @@ from homeassistant.core import HomeAssistant, State from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.setup import async_setup_component -from .common import mock_stt_entity_platform, mock_stt_platform +from .common import ( + TEST_DOMAIN, + MockSTTProvider, + MockSTTProviderEntity, + mock_stt_entity_platform, + mock_stt_platform, +) from tests.common import ( MockConfigEntry, @@ -41,82 +37,17 @@ from tests.common import ( ) from tests.typing import ClientSessionGenerator, WebSocketGenerator -TEST_DOMAIN = "test" - - -class BaseProvider: - """Mock provider.""" - - fail_process_audio = False - - def __init__(self) -> None: - """Init test provider.""" - self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = [] - - @property - def supported_languages(self) -> list[str]: - """Return a list of supported languages.""" - return ["de", "de-CH", "en"] - - @property - def supported_formats(self) -> list[AudioFormats]: - """Return a list of supported formats.""" - return [AudioFormats.WAV, AudioFormats.OGG] - - @property - def supported_codecs(self) -> list[AudioCodecs]: - """Return a list of supported codecs.""" - return [AudioCodecs.PCM, AudioCodecs.OPUS] - - @property - def supported_bit_rates(self) -> list[AudioBitRates]: - """Return a list of supported bitrates.""" - return [AudioBitRates.BITRATE_16] - - @property - def supported_sample_rates(self) -> list[AudioSampleRates]: - """Return a list of supported samplerates.""" - return [AudioSampleRates.SAMPLERATE_16000] - - @property - def supported_channels(self) -> list[AudioChannels]: - """Return a list of supported channels.""" - return [AudioChannels.CHANNEL_MONO] - - async def async_process_audio_stream( - self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] - ) -> SpeechResult: - """Process an audio stream.""" - self.calls.append((metadata, stream)) - if self.fail_process_audio: - return SpeechResult(None, SpeechResultState.ERROR) - - return SpeechResult("test_result", SpeechResultState.SUCCESS) - - -class MockProvider(BaseProvider, Provider): - """Mock provider.""" - - url_path = TEST_DOMAIN - - -class MockProviderEntity(BaseProvider, SpeechToTextEntity): - """Mock provider entity.""" - - url_path = "stt.test" - _attr_name = "test" - @pytest.fixture -def mock_provider() -> MockProvider: +def mock_provider() -> MockSTTProvider: """Test provider fixture.""" - return MockProvider() + return MockSTTProvider() @pytest.fixture -def mock_provider_entity() -> MockProviderEntity: +def mock_provider_entity() -> MockSTTProviderEntity: """Test provider entity fixture.""" - return MockProviderEntity() + return MockSTTProviderEntity() class STTFlow(ConfigFlow): @@ -148,14 +79,14 @@ async def setup_fixture( hass: HomeAssistant, tmp_path: Path, request: pytest.FixtureRequest, -) -> MockProvider | MockProviderEntity: +) -> MockSTTProvider | MockSTTProviderEntity: """Set up the test environment.""" - provider: MockProvider | MockProviderEntity + provider: MockSTTProvider | MockSTTProviderEntity if request.param == "mock_setup": - provider = MockProvider() + provider = MockSTTProvider() await mock_setup(hass, tmp_path, provider) elif request.param == "mock_config_entry_setup": - provider = MockProviderEntity() + provider = MockSTTProviderEntity() await mock_config_entry_setup(hass, tmp_path, provider) else: raise RuntimeError("Invalid setup fixture") @@ -166,7 +97,7 @@ async def setup_fixture( async def mock_setup( hass: HomeAssistant, tmp_path: Path, - mock_provider: MockProvider, + mock_provider: MockSTTProvider, ) -> None: """Set up a test provider.""" mock_stt_platform( @@ -182,7 +113,7 @@ async def mock_setup( async def mock_config_entry_setup( hass: HomeAssistant, tmp_path: Path, - mock_provider_entity: MockProviderEntity, + mock_provider_entity: MockSTTProviderEntity, test_domain: str = TEST_DOMAIN, ) -> MockConfigEntry: """Set up a test provider via config entry.""" @@ -234,7 +165,7 @@ async def mock_config_entry_setup( async def test_get_provider_info( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: MockProvider | MockProviderEntity, + setup: MockSTTProvider | MockSTTProviderEntity, ) -> None: """Test engine that doesn't exist.""" client = await hass_client() @@ -256,7 +187,7 @@ async def test_get_provider_info( async def test_non_existing_provider( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: MockProvider | MockProviderEntity, + setup: MockSTTProvider | MockSTTProviderEntity, ) -> None: """Test streaming to engine that doesn't exist.""" client = await hass_client() @@ -282,7 +213,7 @@ async def test_non_existing_provider( async def test_stream_audio( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: MockProvider | MockProviderEntity, + setup: MockSTTProvider | MockSTTProviderEntity, ) -> None: """Test streaming audio and getting response.""" client = await hass_client() @@ -343,7 +274,7 @@ async def test_metadata_errors( header: str | None, status: int, error: str, - setup: MockProvider | MockProviderEntity, + setup: MockSTTProvider | MockSTTProviderEntity, ) -> None: """Test metadata errors.""" client = await hass_client() @@ -359,7 +290,7 @@ async def test_metadata_errors( async def test_get_provider( hass: HomeAssistant, tmp_path: Path, - mock_provider: MockProvider, + mock_provider: MockSTTProvider, ) -> None: """Test we can get STT providers.""" await mock_setup(hass, tmp_path, mock_provider) @@ -370,7 +301,7 @@ async def test_get_provider( async def test_config_entry_unload( - hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity ) -> None: """Test we can unload config entry.""" config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) @@ -382,7 +313,7 @@ async def test_config_entry_unload( async def test_restore_state( hass: HomeAssistant, tmp_path: Path, - mock_provider_entity: MockProviderEntity, + mock_provider_entity: MockSTTProviderEntity, ) -> None: """Test we restore state in the integration.""" entity_id = f"{DOMAIN}.{TEST_DOMAIN}" @@ -409,7 +340,7 @@ async def test_restore_state( async def test_ws_list_engines( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - setup: MockProvider | MockProviderEntity, + setup: MockSTTProvider | MockSTTProviderEntity, engine_id: str, extra_data: dict[str, str], ) -> None: @@ -491,7 +422,7 @@ async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None: async def test_default_engine( hass: HomeAssistant, tmp_path: Path, - mock_provider: MockProvider, + mock_provider: MockSTTProvider, ) -> None: """Test async_default_engine.""" mock_stt_platform( @@ -507,7 +438,7 @@ async def test_default_engine( async def test_default_engine_entity( - hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity ) -> None: """Test async_default_engine.""" await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) @@ -519,8 +450,8 @@ async def test_default_engine_entity( async def test_default_engine_prefer_entity( hass: HomeAssistant, tmp_path: Path, - mock_provider_entity: MockProviderEntity, - mock_provider: MockProvider, + mock_provider_entity: MockSTTProviderEntity, + mock_provider: MockSTTProvider, config_flow_test_domains: str, ) -> None: """Test async_default_engine. @@ -558,7 +489,7 @@ async def test_default_engine_prefer_entity( async def test_default_engine_prefer_cloud_entity( hass: HomeAssistant, tmp_path: Path, - mock_provider: MockProvider, + mock_provider: MockSTTProvider, config_flow_test_domains: str, ) -> None: """Test async_default_engine. @@ -569,7 +500,7 @@ async def test_default_engine_prefer_cloud_entity( """ await mock_setup(hass, tmp_path, mock_provider) for domain in config_flow_test_domains: - entity = MockProviderEntity() + entity = MockSTTProviderEntity() entity.url_path = f"stt.{domain}" entity._attr_name = f"{domain} STT entity" await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain) @@ -589,7 +520,7 @@ async def test_default_engine_prefer_cloud_entity( async def test_get_engine_legacy( - hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider + hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider ) -> None: """Test async_get_speech_to_text_engine.""" mock_stt_platform( @@ -614,7 +545,7 @@ async def test_get_engine_legacy( async def test_get_engine_entity( - hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity ) -> None: """Test async_get_speech_to_text_engine.""" await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)