From 01a05340c69cb3f5a6159e520b19bef65a13c5e7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 31 Mar 2023 15:04:22 -0400 Subject: [PATCH] Voice Assistant: improve error handling (#90541) Co-authored-by: Michael Hansen --- homeassistant/components/stt/__init__.py | 16 +- .../components/voice_assistant/pipeline.py | 169 ++++++++++++------ .../voice_assistant/websocket_api.py | 53 +++--- .../snapshots/test_websocket.ambr | 25 ++- .../voice_assistant/test_websocket.py | 29 +-- 5 files changed, 178 insertions(+), 114 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 63199402194..b858cc743a2 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -36,12 +36,20 @@ _LOGGER = logging.getLogger(__name__) @callback -def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider: +def async_get_provider( + hass: HomeAssistant, domain: str | None = None +) -> Provider | None: """Return provider.""" - if domain is None: - domain = next(iter(hass.data[DOMAIN])) + if domain: + return hass.data[DOMAIN].get(domain) - return hass.data[DOMAIN][domain] + if not hass.data[DOMAIN]: + return None + + if "cloud" in hass.data[DOMAIN]: + return hass.data[DOMAIN]["cloud"] + + return next(iter(hass.data[DOMAIN].values())) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: diff --git a/homeassistant/components/voice_assistant/pipeline.py b/homeassistant/components/voice_assistant/pipeline.py index 806a603f5e5..ef13d54e6a1 100644 --- a/homeassistant/components/voice_assistant/pipeline.py +++ b/homeassistant/components/voice_assistant/pipeline.py @@ -8,7 +8,7 @@ import logging from typing import Any from homeassistant.backports.enum import StrEnum -from homeassistant.components import conversation, media_source, stt +from homeassistant.components import conversation, media_source, stt, tts from 
homeassistant.components.tts.media_source import ( generate_media_source_id as tts_generate_media_source_id, ) @@ -17,8 +17,6 @@ from homeassistant.util.dt import utcnow from .const import DOMAIN -DEFAULT_TIMEOUT = 30 # seconds - _LOGGER = logging.getLogger(__name__) @@ -151,6 +149,9 @@ class PipelineRun: event_callback: Callable[[PipelineEvent], None] language: str = None # type: ignore[assignment] runner_data: Any | None = None + stt_provider: stt.Provider | None = None + intent_agent: str | None = None + tts_engine: str | None = None def __post_init__(self): """Set language for pipeline.""" @@ -181,13 +182,39 @@ class PipelineRun: ) ) + async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: + """Prepare speech to text.""" + stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine) + engine = self.pipeline.stt_engine or "default" + + if stt_provider is None: + raise SpeechToTextError( + code="stt-provider-missing", + message=f"No speech to text provider for: {engine}", + ) + + if not stt_provider.check_metadata(metadata): + raise SpeechToTextError( + code="stt-provider-unsupported-metadata", + message=( + f"Provider {engine} does not support input speech " + "to text metadata" + ), + ) + + self.stt_provider = stt_provider + async def speech_to_text( self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes], ) -> str: """Run speech to text portion of pipeline. 
Returns the spoken text.""" - engine = self.pipeline.stt_engine or "default" + if self.stt_provider is None: + raise RuntimeError("Speech to text was not prepared") + + engine = self.stt_provider.name + self.event_callback( PipelineEvent( PipelineEventType.STT_START, @@ -198,28 +225,11 @@ class PipelineRun: ) ) - try: - # Load provider - stt_provider: stt.Provider = stt.async_get_provider( - self.hass, self.pipeline.stt_engine - ) - assert stt_provider is not None - except Exception as src_error: - _LOGGER.exception("No speech to text provider for %s", engine) - raise SpeechToTextError( - code="stt-provider-missing", - message=f"No speech to text provider for: {engine}", - ) from src_error - - if not stt_provider.check_metadata(metadata): - raise SpeechToTextError( - code="stt-provider-unsupported-metadata", - message=f"Provider {engine} does not support input speech to text metadata", - ) - try: # Transcribe audio stream - result = await stt_provider.async_process_audio_stream(metadata, stream) + result = await self.stt_provider.async_process_audio_stream( + metadata, stream + ) except Exception as src_error: _LOGGER.exception("Unexpected error during speech to text") raise SpeechToTextError( @@ -253,15 +263,33 @@ class PipelineRun: return result.text + async def prepare_recognize_intent(self) -> None: + """Prepare recognizing an intent.""" + agent_info = conversation.async_get_agent_info( + self.hass, self.pipeline.conversation_engine + ) + + if agent_info is None: + engine = self.pipeline.conversation_engine or "default" + raise IntentRecognitionError( + code="intent-not-supported", + message=f"Intent recognition engine {engine} is not found", + ) + + self.intent_agent = agent_info["id"] + async def recognize_intent( self, intent_input: str, conversation_id: str | None ) -> str: """Run intent recognition portion of pipeline. 
Returns text to speak.""" + if self.intent_agent is None: + raise RuntimeError("Recognize intent was not prepared") + self.event_callback( PipelineEvent( PipelineEventType.INTENT_START, { - "engine": self.pipeline.conversation_engine or "default", + "engine": self.intent_agent, "intent_input": intent_input, }, ) @@ -274,7 +302,7 @@ class PipelineRun: conversation_id=conversation_id, context=self.context, language=self.language, - agent_id=self.pipeline.conversation_engine, + agent_id=self.intent_agent, ) except Exception as src_error: _LOGGER.exception("Unexpected error during intent recognition") @@ -296,13 +324,38 @@ class PipelineRun: return speech + async def prepare_text_to_speech(self) -> None: + """Prepare text to speech.""" + engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine) + + if engine is None: + engine = self.pipeline.tts_engine or "default" + raise TextToSpeechError( + code="tts-not-supported", + message=f"Text to speech engine '{engine}' not found", + ) + + if not await tts.async_support_options(self.hass, engine, self.language): + raise TextToSpeechError( + code="tts-not-supported", + message=( + f"Text to speech engine {engine} " + f"does not support language {self.language}" + ), + ) + + self.tts_engine = engine + async def text_to_speech(self, tts_input: str) -> str: """Run text to speech portion of pipeline. 
Returns URL of TTS audio.""" + if self.tts_engine is None: + raise RuntimeError("Text to speech was not prepared") + self.event_callback( PipelineEvent( PipelineEventType.TTS_START, { - "engine": self.pipeline.tts_engine or "default", + "engine": self.tts_engine, "tts_input": tts_input, }, ) @@ -315,7 +368,8 @@ class PipelineRun: tts_generate_media_source_id( self.hass, tts_input, - engine=self.pipeline.tts_engine, + engine=self.tts_engine, + language=self.language, ), ) except Exception as src_error: @@ -341,6 +395,8 @@ class PipelineRun: class PipelineInput: """Input to a pipeline run.""" + run: PipelineRun + stt_metadata: stt.SpeechMetadata | None = None """Metadata of stt input audio. Required when start_stage = stt.""" @@ -355,21 +411,10 @@ class PipelineInput: conversation_id: str | None = None - async def execute( - self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT - ): - """Run pipeline with optional timeout.""" - await asyncio.wait_for( - self._execute(run), - timeout=timeout, - ) - - async def _execute(self, run: PipelineRun): - self._validate(run.start_stage) - - # stt -> intent -> tts - run.start() - current_stage = run.start_stage + async def execute(self): + """Run pipeline.""" + self.run.start() + current_stage = self.run.start_stage try: # Speech to text @@ -377,29 +422,29 @@ class PipelineInput: if current_stage == PipelineStage.STT: assert self.stt_metadata is not None assert self.stt_stream is not None - intent_input = await run.speech_to_text( + intent_input = await self.run.speech_to_text( self.stt_metadata, self.stt_stream, ) current_stage = PipelineStage.INTENT - if run.end_stage != PipelineStage.STT: + if self.run.end_stage != PipelineStage.STT: tts_input = self.tts_input if current_stage == PipelineStage.INTENT: assert intent_input is not None - tts_input = await run.recognize_intent( + tts_input = await self.run.recognize_intent( intent_input, self.conversation_id ) current_stage = PipelineStage.TTS - if run.end_stage 
!= PipelineStage.INTENT: + if self.run.end_stage != PipelineStage.INTENT: if current_stage == PipelineStage.TTS: assert tts_input is not None - await run.text_to_speech(tts_input) + await self.run.text_to_speech(tts_input) except PipelineError as err: - run.event_callback( + self.run.event_callback( PipelineEvent( PipelineEventType.ERROR, {"code": err.code, "message": err.message}, @@ -407,11 +452,11 @@ class PipelineInput: ) return - run.end() + self.run.end() - def _validate(self, stage: PipelineStage): + async def validate(self): """Validate pipeline input against start stage.""" - if stage == PipelineStage.STT: + if self.run.start_stage == PipelineStage.STT: if self.stt_metadata is None: raise PipelineRunValidationError( "stt_metadata is required for speech to text" @@ -421,13 +466,29 @@ class PipelineInput: raise PipelineRunValidationError( "stt_stream is required for speech to text" ) - elif stage == PipelineStage.INTENT: + elif self.run.start_stage == PipelineStage.INTENT: if self.intent_input is None: raise PipelineRunValidationError( "intent_input is required for intent recognition" ) - elif stage == PipelineStage.TTS: + elif self.run.start_stage == PipelineStage.TTS: if self.tts_input is None: raise PipelineRunValidationError( "tts_input is required for text to speech" ) + + start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage) + + prepare_tasks = [] + + if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT): + prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) + + if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT): + prepare_tasks.append(self.run.prepare_recognize_intent()) + + if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS): + prepare_tasks.append(self.run.prepare_text_to_speech()) + + if prepare_tasks: + await asyncio.gather(*prepare_tasks) diff --git a/homeassistant/components/voice_assistant/websocket_api.py 
b/homeassistant/components/voice_assistant/websocket_api.py index 28cafb7a355..aa295ad5c62 100644 --- a/homeassistant/components/voice_assistant/websocket_api.py +++ b/homeassistant/components/voice_assistant/websocket_api.py @@ -5,13 +5,13 @@ from collections.abc import Callable import logging from typing import Any +import async_timeout import voluptuous as vol from homeassistant.components import stt, websocket_api from homeassistant.core import HomeAssistant, callback from .pipeline import ( - DEFAULT_TIMEOUT, PipelineError, PipelineEvent, PipelineEventType, @@ -21,6 +21,8 @@ from .pipeline import ( async_get_pipeline, ) +DEFAULT_TIMEOUT = 30 + _LOGGER = logging.getLogger(__name__) _VAD_ENERGY_THRESHOLD = 1000 @@ -155,37 +157,40 @@ async def websocket_run( # Input to text to speech system input_args["tts_input"] = msg["input"]["text"] - run_task = hass.async_create_task( - PipelineInput(**input_args).execute( - PipelineRun( - hass, - context=connection.context(msg), - pipeline=pipeline, - start_stage=start_stage, - end_stage=end_stage, - event_callback=lambda event: connection.send_event( - msg["id"], event.as_dict() - ), - runner_data={ - "stt_binary_handler_id": handler_id, - }, - ), - timeout=timeout, - ) + input_args["run"] = PipelineRun( + hass, + context=connection.context(msg), + pipeline=pipeline, + start_stage=start_stage, + end_stage=end_stage, + event_callback=lambda event: connection.send_event(msg["id"], event.as_dict()), + runner_data={ + "stt_binary_handler_id": handler_id, + "timeout": timeout, + }, ) - # Cancel pipeline if user unsubscribes - connection.subscriptions[msg["id"]] = run_task.cancel + pipeline_input = PipelineInput(**input_args) + + try: + await pipeline_input.validate() + except PipelineError as error: + # Report more specific error when possible + connection.send_error(msg["id"], error.code, error.message) + return # Confirm subscription connection.send_result(msg["id"]) + run_task = 
hass.async_create_task(pipeline_input.execute()) + + # Cancel pipeline if user unsubscribes + connection.subscriptions[msg["id"]] = run_task.cancel + try: # Task contains a timeout - await run_task - except PipelineError as error: - # Report more specific error when possible - connection.send_error(msg["id"], error.code, error.message) + async with async_timeout.timeout(timeout): + await run_task except asyncio.TimeoutError: connection.send_event( msg["id"], diff --git a/tests/components/voice_assistant/snapshots/test_websocket.ambr b/tests/components/voice_assistant/snapshots/test_websocket.ambr index c18af44b21c..a5812d170f6 100644 --- a/tests/components/voice_assistant/snapshots/test_websocket.ambr +++ b/tests/components/voice_assistant/snapshots/test_websocket.ambr @@ -5,12 +5,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': 1, + 'timeout': 30, }), }) # --- # name: test_audio_pipeline.1 dict({ - 'engine': 'default', + 'engine': 'test', 'metadata': dict({ 'bit_rate': 16, 'channel': 1, @@ -30,7 +31,7 @@ # --- # name: test_audio_pipeline.3 dict({ - 'engine': 'default', + 'engine': 'homeassistant', 'intent_input': 'test transcript', }) # --- @@ -58,7 +59,7 @@ # --- # name: test_audio_pipeline.5 dict({ - 'engine': 'default', + 'engine': 'test', 'tts_input': "Sorry, I couldn't understand that", }) # --- @@ -66,7 +67,7 @@ dict({ 'tts_output': dict({ 'mime_type': 'audio/mpeg', - 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_test.mp3', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', }), }) # --- @@ -76,12 +77,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': None, + 'timeout': 30, }), }) # --- # name: test_intent_failed.1 dict({ - 'engine': 'default', + 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', }) # --- @@ -91,12 +93,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': None, + 'timeout': 0.1, }), }) # --- # name: 
test_intent_timeout.1 dict({ - 'engine': 'default', + 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', }) # --- @@ -112,6 +115,7 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': 1, + 'timeout': 30, }), }) # --- @@ -134,12 +138,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': 1, + 'timeout': 30, }), }) # --- # name: test_stt_stream_failed.1 dict({ - 'engine': 'default', + 'engine': 'test', 'metadata': dict({ 'bit_rate': 16, 'channel': 1, @@ -156,12 +161,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': None, + 'timeout': 30, }), }) # --- # name: test_text_only_pipeline.1 dict({ - 'engine': 'default', + 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', }) # --- @@ -199,12 +205,13 @@ 'pipeline': 'en-US', 'runner_data': dict({ 'stt_binary_handler_id': None, + 'timeout': 30, }), }) # --- # name: test_tts_failed.1 dict({ - 'engine': 'default', + 'engine': 'test', 'tts_input': 'Lights are on.', }) # --- diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py index 149d896dcf6..ce876550327 100644 --- a/tests/components/voice_assistant/test_websocket.py +++ b/tests/components/voice_assistant/test_websocket.py @@ -93,7 +93,7 @@ class MockTTSProvider(tts.Provider): @property def supported_languages(self) -> list[str]: """Return list of supported languages.""" - return ["en"] + return ["en-US"] @property def supported_options(self) -> list[str]: @@ -264,7 +264,7 @@ async def test_intent_timeout( "start_stage": "intent", "end_stage": "intent", "input": {"text": "Are the lights on?"}, - "timeout": 0.00001, + "timeout": 0.1, } ) @@ -301,7 +301,7 @@ async def test_text_pipeline_timeout( await asyncio.sleep(3600) with patch( - "homeassistant.components.voice_assistant.pipeline.PipelineInput._execute", + "homeassistant.components.voice_assistant.pipeline.PipelineInput.execute", new=sleepy_run, ): await client.send_json( @@ 
-381,7 +381,7 @@ async def test_audio_pipeline_timeout( await asyncio.sleep(3600) with patch( - "homeassistant.components.voice_assistant.pipeline.PipelineInput._execute", + "homeassistant.components.voice_assistant.pipeline.PipelineInput.execute", new=sleepy_run, ): await client.send_json( @@ -427,25 +427,8 @@ async def test_stt_provider_missing( # result msg = await client.receive_json() - assert msg["success"] - - # run start - msg = await client.receive_json() - assert msg["event"]["type"] == "run-start" - assert msg["event"]["data"] == snapshot - - # stt - msg = await client.receive_json() - assert msg["event"]["type"] == "stt-start" - assert msg["event"]["data"] == snapshot - - # End of audio stream (handler id + empty payload) - await client.send_bytes(b"1") - - # stt error - msg = await client.receive_json() - assert msg["event"]["type"] == "error" - assert msg["event"]["data"]["code"] == "stt-provider-missing" + assert not msg["success"] + assert msg["error"]["code"] == "stt-provider-missing" async def test_stt_stream_failed(