diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index d347e433f46..6a4bbdf61e6 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -5,7 +5,7 @@ import asyncio from collections.abc import AsyncIterable, Callable, Iterable from dataclasses import asdict, dataclass, field import logging -from typing import Any +from typing import Any, cast import voluptuous as vol @@ -332,12 +332,12 @@ class PipelineRun: event_callback: PipelineEventCallback language: str = None # type: ignore[assignment] runner_data: Any | None = None - stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None intent_agent: str | None = None - tts_engine: str | None = None tts_audio_output: str | None = None id: str = field(default_factory=ulid_util.ulid) + stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False) + tts_engine: str = field(init=False) tts_options: dict | None = field(init=False, default=None) def __post_init__(self) -> None: @@ -388,8 +388,6 @@ class PipelineRun: async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: """Prepare speech to text.""" - stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None - # pipeline.stt_engine can't be None or this function is not called stt_provider = stt.async_get_speech_to_text_engine( self.hass, @@ -422,9 +420,6 @@ class PipelineRun: stream: AsyncIterable[bytes], ) -> str: """Run speech to text portion of pipeline. Returns the spoken text.""" - if self.stt_provider is None: - raise RuntimeError("Speech to text was not prepared") - if isinstance(self.stt_provider, stt.Provider): engine = self.stt_provider.name else: @@ -547,7 +542,8 @@ class PipelineRun: async def prepare_text_to_speech(self) -> None: """Prepare text to speech.""" - engine = self.pipeline.tts_engine + # pipeline.tts_engine can't be None or this function is not called + engine = cast(str, self.pipeline.tts_engine) tts_options = {} if self.pipeline.tts_voice is not None: @@ -557,34 +553,31 @@ class PipelineRun: tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output try: - # pipeline.tts_engine can't be None or this function is not called - if not await tts.async_support_options( + options_supported = await tts.async_support_options( self.hass, - engine, # type: ignore[arg-type] + engine, self.pipeline.tts_language, tts_options, - ): - raise TextToSpeechError( - code="tts-not-supported", - message=( - f"Text to speech engine {engine} " - f"does not support language {self.pipeline.tts_language} or options {tts_options}" - ), - ) + ) except HomeAssistantError as err: raise TextToSpeechError( code="tts-not-supported", message=f"Text to speech engine '{engine}' not found", ) from err + if not options_supported: + raise TextToSpeechError( + code="tts-not-supported", + message=( + f"Text to speech engine {engine} " + f"does not support language {self.pipeline.tts_language} or options {tts_options}" + ), + ) self.tts_engine = engine self.tts_options = tts_options async def text_to_speech(self, tts_input: str) -> str: """Run text to speech portion of pipeline. Returns URL of TTS audio.""" - if self.tts_engine is None: - raise RuntimeError("Text to speech was not prepared") - self.process_event( PipelineEvent( PipelineEventType.TTS_START, diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 6fb6bf61d96..392363fc0cc 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -241,3 +241,42 @@ async def test_pipeline_from_audio_stream_no_stt( ) assert not events + + +async def test_pipeline_from_audio_stream_unknown_pipeline( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + mock_stt_provider: MockSttProvider, + init_components, + snapshot: SnapshotAssertion, +) -> None: + """Test creating a pipeline from an audio stream. + + In this test, the pipeline does not exist. + """ + events = [] + + async def audio_data(): + yield b"part1" + yield b"part2" + yield b"" + + # Try to use the created pipeline + with pytest.raises(assist_pipeline.PipelineNotFound): + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + Context(), + events.append, + stt.SpeechMetadata( + language="en-UK", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + audio_data(), + pipeline_id="blah", + ) + + assert not events diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index c71d0526fe6..95e2c33ef5b 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -7,6 +7,7 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError from tests.typing import WebSocketGenerator @@ -430,6 +431,34 @@ async def test_stt_provider_missing( assert msg["error"]["code"] == "stt-provider-missing" +async def test_stt_provider_bad_metadata( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_stt_provider, + snapshot: SnapshotAssertion, +) -> None: + """Test events from a pipeline run with wrong metadata.""" + with patch.object(mock_stt_provider, "check_metadata", return_value=False): + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "stt", + "end_stage": "tts", + "input": { + "sample_rate": 12345, + }, + } + ) + + # result + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "stt-provider-unsupported-metadata" + + async def test_stt_stream_failed( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, @@ -559,6 +588,64 @@ async def test_tts_failed( assert msg["result"] == {"events": events} +async def test_tts_provider_missing( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_tts_provider, + snapshot: SnapshotAssertion, +) -> None: + """Test pipeline run with text to speech error.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.tts.async_support_options", + side_effect=HomeAssistantError, + ): + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "tts", + "end_stage": "tts", + "input": {"text": "Lights are on."}, + } + ) + + # result + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "tts-not-supported" + + +async def test_tts_provider_bad_options( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, + mock_tts_provider, + snapshot: SnapshotAssertion, +) -> None: + """Test pipeline run with text to speech error.""" + client = await hass_ws_client(hass) + + with patch( + "homeassistant.components.tts.async_support_options", + return_value=False, + ): + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "tts", + "end_stage": "tts", + "input": {"text": "Lights are on."}, + } + ) + + # result + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == "tts-not-supported" + + async def test_invalid_stage_order( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components ) -> None: