From a73e86a74160a71354c433acb1c104480dd15e41 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 12 Dec 2023 23:21:16 -0600 Subject: [PATCH] Skip TTS events entirely with empty text (#105617) --- .../components/assist_pipeline/pipeline.py | 60 ++++++++++--------- .../snapshots/test_websocket.ambr | 28 +++++++-- .../assist_pipeline/test_websocket.py | 11 ++-- 3 files changed, 59 insertions(+), 40 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index ed9029d1c2c..26d599da836 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -369,6 +369,7 @@ class PipelineStage(StrEnum): STT = "stt" INTENT = "intent" TTS = "tts" + END = "end" PIPELINE_STAGE_ORDER = [ @@ -1024,35 +1025,32 @@ class PipelineRun: ) ) - if tts_input := tts_input.strip(): - try: - # Synthesize audio and get URL - tts_media_id = tts_generate_media_source_id( - self.hass, - tts_input, - engine=self.tts_engine, - language=self.pipeline.tts_language, - options=self.tts_options, - ) - tts_media = await media_source.async_resolve_media( - self.hass, - tts_media_id, - None, - ) - except Exception as src_error: - _LOGGER.exception("Unexpected error during text-to-speech") - raise TextToSpeechError( - code="tts-failed", - message="Unexpected error during text-to-speech", - ) from src_error + try: + # Synthesize audio and get URL + tts_media_id = tts_generate_media_source_id( + self.hass, + tts_input, + engine=self.tts_engine, + language=self.pipeline.tts_language, + options=self.tts_options, + ) + tts_media = await media_source.async_resolve_media( + self.hass, + tts_media_id, + None, + ) + except Exception as src_error: + _LOGGER.exception("Unexpected error during text-to-speech") + raise TextToSpeechError( + code="tts-failed", + message="Unexpected error during text-to-speech", + ) from src_error - _LOGGER.debug("TTS result %s", tts_media) - tts_output = { - "media_id": tts_media_id, - **asdict(tts_media), - } - else: - tts_output = {} + _LOGGER.debug("TTS result %s", tts_media) + tts_output = { + "media_id": tts_media_id, + **asdict(tts_media), + } self.process_event( PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output}) @@ -1345,7 +1343,11 @@ class PipelineInput: self.conversation_id, self.device_id, ) - current_stage = PipelineStage.TTS + if tts_input.strip(): + current_stage = PipelineStage.TTS + else: + # Skip TTS + current_stage = PipelineStage.END if self.run.end_stage != PipelineStage.INTENT: # text-to-speech diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 072b1ff730a..c165675a6ff 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -662,15 +662,33 @@ # --- # name: test_pipeline_empty_tts_output.1 dict({ - 'engine': 'test', - 'language': 'en-US', - 'tts_input': '', - 'voice': 'james_earl_jones', + 'conversation_id': None, + 'device_id': None, + 'engine': 'homeassistant', + 'intent_input': 'never mind', + 'language': 'en', }) # --- # name: test_pipeline_empty_tts_output.2 dict({ - 'tts_output': dict({ + 'intent_output': dict({ + 'conversation_id': None, + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'failed': list([ + ]), + 'success': list([ + ]), + 'targets': list([ + ]), + }), + 'language': 'en', + 'response_type': 'action_done', + 'speech': dict({ + }), + }), }), }) # --- diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 0e2a3ad538c..458320a9a90 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -2467,10 +2467,10 @@ async def test_pipeline_empty_tts_output( await client.send_json_auto_id( { "type": "assist_pipeline/run", - "start_stage": "tts", + "start_stage": "intent", "end_stage": "tts", "input": { - "text": "", + "text": "never mind", }, } ) @@ -2486,16 +2486,15 @@ async def test_pipeline_empty_tts_output( assert msg["event"]["data"] == snapshot events.append(msg["event"]) - # text-to-speech + # intent msg = await client.receive_json() - assert msg["event"]["type"] == "tts-start" + assert msg["event"]["type"] == "intent-start" assert msg["event"]["data"] == snapshot events.append(msg["event"]) msg = await client.receive_json() - assert msg["event"]["type"] == "tts-end" + assert msg["event"]["type"] == "intent-end" assert msg["event"]["data"] == snapshot - assert not msg["event"]["data"]["tts_output"] events.append(msg["event"]) # run end