From 1c3e1d2e134e405333ba37f737979ad853dbac76 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 24 Apr 2023 14:40:11 +0200 Subject: [PATCH] Don't resolve default tts engine in assist pipelines (#91943) * Don't resolve default tts engine in assist pipelines * Set tts engine when creating default pipeline * Update tests --- .../components/assist_pipeline/pipeline.py | 50 +++++++++++-------- homeassistant/components/tts/__init__.py | 32 ++++++++---- .../assist_pipeline/snapshots/test_init.ambr | 4 +- .../snapshots/test_websocket.ambr | 10 ++-- 4 files changed, 58 insertions(+), 38 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index ef528bec4ae..fd0df4edf92 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -15,6 +15,7 @@ from homeassistant.components.tts.media_source import ( generate_media_source_id as tts_generate_media_source_id, ) from homeassistant.core import Context, HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.collection import ( CollectionError, ItemNotFound, @@ -85,6 +86,8 @@ async def async_get_pipeline( # configured language stt_engine = stt.async_default_provider(hass) stt_language = hass.config.language if stt_engine else None + tts_engine = tts.async_default_engine(hass) + tts_language = hass.config.language if tts_engine else None return await pipeline_data.pipeline_store.async_create_item( { "conversation_engine": None, @@ -93,8 +96,8 @@ async def async_get_pipeline( "name": hass.config.language, "stt_engine": stt_engine, "stt_language": stt_language, - "tts_engine": None, - "tts_language": None, + "tts_engine": tts_engine, + "tts_language": tts_language, "tts_voice": None, } ) @@ -420,14 +423,7 @@ class PipelineRun: async def prepare_text_to_speech(self) -> None: """Prepare text to speech.""" - engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine) - - if engine is None: - engine = self.pipeline.tts_engine or "default" - raise TextToSpeechError( - code="tts-not-supported", - message=f"Text to speech engine '{engine}' not found", - ) + engine = self.pipeline.tts_engine tts_options = {} if self.pipeline.tts_voice is not None: @@ -436,19 +432,26 @@ class PipelineRun: if self.tts_audio_output is not None: tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output - if not await tts.async_support_options( - self.hass, - engine, - self.pipeline.tts_language, - tts_options, - ): + try: + # pipeline.tts_engine can't be None or this function is not called + if not await tts.async_support_options( + self.hass, + engine, # type: ignore[arg-type] + self.pipeline.tts_language, + tts_options, + ): + raise TextToSpeechError( + code="tts-not-supported", + message=( + f"Text to speech engine {engine} " + f"does not support language {self.pipeline.tts_language} or options {tts_options}" + ), + ) + except HomeAssistantError as err: raise TextToSpeechError( code="tts-not-supported", - message=( - f"Text to speech engine {engine} " - f"does not support language {self.pipeline.tts_language} or options {tts_options}" - ), - ) + message=f"Text to speech engine '{engine}' not found", + ) from err self.tts_engine = engine self.tts_options = tts_options @@ -596,6 +599,11 @@ class PipelineInput: raise PipelineRunValidationError( "tts_input is required for text to speech" ) + if self.run.end_stage == PipelineStage.TTS: + if self.run.pipeline.tts_engine is None: + raise PipelineRunValidationError( + "the pipeline does not support text to speech" + ) start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index fa4a989d685..16b78a40dcc 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -69,8 +69,8 @@ from .media_source import generate_media_source_id, media_source_id_to_kwargs from .models import Voice __all__ = [ + "async_default_engine", "async_get_media_source_audio", - "async_resolve_engine", "async_support_options", "ATTR_AUDIO_OUTPUT", "CONF_LANG", @@ -116,6 +116,26 @@ class TTSCache(TypedDict): pending: asyncio.Task | None +@callback +def async_default_engine(hass: HomeAssistant) -> str | None: + """Return the domain or entity id of the default engine. + + Returns None if no engines found. + """ + component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] + manager: SpeechManager = hass.data[DATA_TTS_MANAGER] + + if "cloud" in manager.providers: + return "cloud" + + entity = next(iter(component.entities), None) + + if entity is not None: + return entity.entity_id + + return next(iter(manager.providers), None) + + @callback def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None: """Resolve engine. @@ -130,15 +150,7 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None: return None return engine - if "cloud" in manager.providers: - return "cloud" - - entity = next(iter(component.entities), None) - - if entity is not None: - return entity.entity_id - - return next(iter(manager.providers), None) + return async_default_engine(hass) async def async_support_options( diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index b5c636b4bd6..efa6434e784 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -64,7 +64,7 @@ dict({ 'data': dict({ 'engine': 'test', - 'language': None, + 'language': 'en', 'tts_input': "Sorry, I couldn't understand that", 'voice': None, }), @@ -73,7 +73,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index a77cee41afb..0abb00afdfb 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -61,7 +61,7 @@ # name: test_audio_pipeline.5 dict({ 'engine': 'test', - 'language': None, + 'language': 'en', 'tts_input': "Sorry, I couldn't understand that", 'voice': None, }) @@ -69,7 +69,7 @@ # name: test_audio_pipeline.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', }), @@ -137,7 +137,7 @@ # name: test_audio_pipeline_debug.5 dict({ 'engine': 'test', - 'language': None, + 'language': 'en', 'tts_input': "Sorry, I couldn't understand that", 'voice': None, }) @@ -145,7 +145,7 @@ # name: test_audio_pipeline_debug.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', }), @@ -295,7 +295,7 @@ # name: test_tts_failed.1 dict({ 'engine': 'test', - 'language': None, + 'language': 'en', 'tts_input': 'Lights are on.', 'voice': None, })