diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 6e69fe0dc2e..ef528bec4ae 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -83,14 +83,16 @@ async def async_get_pipeline( if pipeline_id is None: # There's no preferred pipeline, construct a pipeline for the # configured language + stt_engine = stt.async_default_provider(hass) + stt_language = hass.config.language if stt_engine else None return await pipeline_data.pipeline_store.async_create_item( { "conversation_engine": None, "conversation_language": None, "language": hass.config.language, "name": hass.config.language, - "stt_engine": None, - "stt_language": None, + "stt_engine": stt_engine, + "stt_language": stt_language, "tts_engine": None, "tts_language": None, "tts_voice": None, @@ -261,22 +263,14 @@ class PipelineRun: """Prepare speech to text.""" stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None - if self.pipeline.stt_engine is not None: - # Try entity first - stt_provider = stt.async_get_speech_to_text_entity( - self.hass, - self.pipeline.stt_engine, - ) + # pipeline.stt_engine can't be None or this function is not called + stt_provider = stt.async_get_speech_to_text_engine( + self.hass, + self.pipeline.stt_engine, # type: ignore[arg-type] + ) if stt_provider is None: - # Try legacy provider second - stt_provider = stt.async_get_provider( - self.hass, - self.pipeline.stt_engine, - ) - - if stt_provider is None: - engine = self.pipeline.stt_engine or "default" + engine = self.pipeline.stt_engine raise SpeechToTextError( code="stt-provider-missing", message=f"No speech to text provider for: {engine}", @@ -580,11 +574,14 @@ class PipelineInput: async def validate(self) -> None: """Validate pipeline input against start stage.""" if self.run.start_stage == PipelineStage.STT: + if self.run.pipeline.stt_engine is None: + raise PipelineRunValidationError( + "the pipeline does not support speech to text" + ) if self.stt_metadata is None: raise PipelineRunValidationError( "stt_metadata is required for speech to text" ) - if self.stt_stream is None: raise PipelineRunValidationError( "stt_stream is required for speech to text" diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 0a7711d5619..e26603a5bb5 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -41,12 +41,14 @@ from .legacy import ( Provider, SpeechMetadata, SpeechResult, + async_default_provider, async_get_provider, async_setup_legacy, ) __all__ = [ "async_get_provider", + "async_get_speech_to_text_engine", "async_get_speech_to_text_entity", "AudioBitRates", "AudioChannels", @@ -64,6 +66,14 @@ __all__ = [ _LOGGER = logging.getLogger(__name__) +@callback +def async_default_engine(hass: HomeAssistant) -> str | None: + """Return the domain or entity id of the default engine.""" + return async_default_provider(hass) or next( + iter(hass.states.async_entity_ids(DOMAIN)), None + ) + + @callback def async_get_speech_to_text_entity( hass: HomeAssistant, entity_id: str @@ -74,6 +84,16 @@ def async_get_speech_to_text_entity( return component.get_entity(entity_id) +@callback +def async_get_speech_to_text_engine( + hass: HomeAssistant, engine_id: str +) -> SpeechToTextEntity | Provider | None: + """Return stt entity or legacy provider.""" + if entity := async_get_speech_to_text_entity(hass, engine_id): + return entity + return async_get_provider(hass, engine_id) + + @callback def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]: """Return a set with the union of languages supported by stt engines.""" diff --git a/homeassistant/components/stt/legacy.py b/homeassistant/components/stt/legacy.py index be8429b9a4f..7c126849690 100644 --- a/homeassistant/components/stt/legacy.py +++ b/homeassistant/components/stt/legacy.py @@ -27,6 +27,15 @@ from .const import ( _LOGGER = logging.getLogger(__name__) +@callback +def async_default_provider(hass: HomeAssistant) -> str | None: + """Return the domain of the default provider.""" + if "cloud" in hass.data[DATA_PROVIDERS]: + return "cloud" + + return next(iter(hass.data[DATA_PROVIDERS]), None) + + @callback def async_get_provider( hass: HomeAssistant, domain: str | None = None @@ -35,13 +44,8 @@ def async_get_provider( if domain: return hass.data[DATA_PROVIDERS].get(domain) - if not hass.data[DATA_PROVIDERS]: - return None - - if "cloud" in hass.data[DATA_PROVIDERS]: - return hass.data[DATA_PROVIDERS]["cloud"] - - return next(iter(hass.data[DATA_PROVIDERS].values())) + provider = async_default_provider(hass) + return hass.data[DATA_PROVIDERS][provider] if provider is not None else None @callback diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 6b8ab8f4e96..cbab6875b10 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -1,6 +1,7 @@ """Test Voice Assistant init.""" from dataclasses import asdict +import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import assist_pipeline, stt @@ -184,3 +185,63 @@ async def test_pipeline_from_audio_stream_entity( assert processed == snapshot assert mock_stt_provider_entity.received == [b"part1", b"part2"] + + +async def test_pipeline_from_audio_stream_no_stt( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + mock_stt_provider: MockSttProvider, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test creating a pipeline from an audio stream. + + In this test, the pipeline does not support stt + """ + client = await hass_ws_client(hass) + + events = [] + + async def audio_data(): + yield b"part1" + yield b"part2" + yield b"" + + # Create a pipeline without stt support + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "homeassistant", + "conversation_language": "en-US", + "language": "en", + "name": "test_name", + "stt_engine": None, + "stt_language": None, + "tts_engine": "test", + "tts_language": "en-AU", + "tts_voice": "Arnold Schwarzenegger", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + + # Try to use the created pipeline + with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError): + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + Context(), + events.append, + stt.SpeechMetadata( + language="en-UK", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + pipeline_id=pipeline_id, + ) + + assert not events diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index e7e0decde72..5706e54d2f0 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -18,7 +18,9 @@ from homeassistant.components.stt import ( SpeechResult, SpeechResultState, SpeechToTextEntity, + async_default_engine, async_get_provider, + async_get_speech_to_text_engine, ) from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow from homeassistant.core import HomeAssistant, State @@ -349,6 +351,9 @@ async def test_get_provider( await mock_setup(hass, tmp_path, mock_provider) assert mock_provider == async_get_provider(hass, TEST_DOMAIN) + # Test getting the default provider + assert mock_provider == async_get_provider(hass) + async def test_config_entry_unload( hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity @@ -444,3 +449,84 @@ async def test_ws_list_engines( assert msg["result"] == { "providers": [{"engine_id": engine_id, "supported_languages": ["de-CH", "de"]}] } + + +async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None: + """Test async_default_engine.""" + assert await async_setup_component(hass, "stt", {"stt": {}}) + await hass.async_block_till_done() + + assert async_default_engine(hass) is None + + +async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None: + """Test async_default_engine.""" + mock_stt_platform( + hass, + tmp_path, + TEST_DOMAIN, + async_get_engine=AsyncMock(return_value=mock_provider), + ) + assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}}) + await hass.async_block_till_done() + + assert async_default_engine(hass) == TEST_DOMAIN + + +async def test_default_engine_entity( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test async_default_engine.""" + await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + + assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}" + + +async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None: + """Test async_default_engine.""" + mock_stt_platform( + hass, + tmp_path, + TEST_DOMAIN, + async_get_engine=AsyncMock(return_value=mock_provider), + ) + mock_stt_platform( + hass, + tmp_path, + "cloud", + async_get_engine=AsyncMock(return_value=mock_provider), + ) + assert await async_setup_component( + hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]} + ) + await hass.async_block_till_done() + + assert async_default_engine(hass) == "cloud" + + +async def test_get_engine_legacy( + hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider +) -> None: + """Test async_get_speech_to_text_engine.""" + mock_stt_platform( + hass, + tmp_path, + TEST_DOMAIN, + async_get_engine=AsyncMock(return_value=mock_provider), + ) + assert await async_setup_component( + hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]} + ) + await hass.async_block_till_done() + + assert async_get_speech_to_text_engine(hass, "no_such_provider") is None + assert async_get_speech_to_text_engine(hass, "test") is mock_provider + + +async def test_get_engine_entity( + hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity +) -> None: + """Test async_get_speech_to_text_engine.""" + await mock_config_entry_setup(hass, tmp_path, mock_provider_entity) + + assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity