Don't resolve default tts engine in assist pipelines (#91943)

* Don't resolve default tts engine in assist pipelines

* Set tts engine when creating default pipeline

* Update tests
This commit is contained in:
Erik Montnemery 2023-04-24 14:40:11 +02:00 committed by GitHub
parent 392a9f32c9
commit 1c3e1d2e13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 38 deletions

View File

@ -15,6 +15,7 @@ from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.collection import (
CollectionError,
ItemNotFound,
@ -85,6 +86,8 @@ async def async_get_pipeline(
# configured language
stt_engine = stt.async_default_provider(hass)
stt_language = hass.config.language if stt_engine else None
tts_engine = tts.async_default_engine(hass)
tts_language = hass.config.language if tts_engine else None
return await pipeline_data.pipeline_store.async_create_item(
{
"conversation_engine": None,
@ -93,8 +96,8 @@ async def async_get_pipeline(
"name": hass.config.language,
"stt_engine": stt_engine,
"stt_language": stt_language,
"tts_engine": None,
"tts_language": None,
"tts_engine": tts_engine,
"tts_language": tts_language,
"tts_voice": None,
}
)
@ -420,14 +423,7 @@ class PipelineRun:
async def prepare_text_to_speech(self) -> None:
"""Prepare text to speech."""
engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)
if engine is None:
engine = self.pipeline.tts_engine or "default"
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text to speech engine '{engine}' not found",
)
engine = self.pipeline.tts_engine
tts_options = {}
if self.pipeline.tts_voice is not None:
@ -436,19 +432,26 @@ class PipelineRun:
if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
if not await tts.async_support_options(
self.hass,
engine,
self.pipeline.tts_language,
tts_options,
):
try:
# pipeline.tts_engine can't be None or this function is not called
if not await tts.async_support_options(
self.hass,
engine, # type: ignore[arg-type]
self.pipeline.tts_language,
tts_options,
):
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
message=f"Text to speech engine '{engine}' not found",
) from err
self.tts_engine = engine
self.tts_options = tts_options
@ -596,6 +599,11 @@ class PipelineInput:
raise PipelineRunValidationError(
"tts_input is required for text to speech"
)
if self.run.end_stage == PipelineStage.TTS:
if self.run.pipeline.tts_engine is None:
raise PipelineRunValidationError(
"the pipeline does not support text to speech"
)
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)

View File

@ -69,8 +69,8 @@ from .media_source import generate_media_source_id, media_source_id_to_kwargs
from .models import Voice
__all__ = [
"async_default_engine",
"async_get_media_source_audio",
"async_resolve_engine",
"async_support_options",
"ATTR_AUDIO_OUTPUT",
"CONF_LANG",
@ -116,6 +116,26 @@ class TTSCache(TypedDict):
pending: asyncio.Task | None
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine.
Returns None if no engines found.
"""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if "cloud" in manager.providers:
return "cloud"
entity = next(iter(component.entities), None)
if entity is not None:
return entity.entity_id
return next(iter(manager.providers), None)
@callback
def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
"""Resolve engine.
@ -130,15 +150,7 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
return None
return engine
if "cloud" in manager.providers:
return "cloud"
entity = next(iter(component.entities), None)
if entity is not None:
return entity.entity_id
return next(iter(manager.providers), None)
return async_default_engine(hass)
async def async_support_options(

View File

@ -64,7 +64,7 @@
dict({
'data': dict({
'engine': 'test',
'language': None,
'language': 'en',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
}),
@ -73,7 +73,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),

View File

@ -61,7 +61,7 @@
# name: test_audio_pipeline.5
dict({
'engine': 'test',
'language': None,
'language': 'en',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
})
@ -69,7 +69,7 @@
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@ -137,7 +137,7 @@
# name: test_audio_pipeline_debug.5
dict({
'engine': 'test',
'language': None,
'language': 'en',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
})
@ -145,7 +145,7 @@
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@ -295,7 +295,7 @@
# name: test_tts_failed.1
dict({
'engine': 'test',
'language': None,
'language': 'en',
'tts_input': 'Lights are on.',
'voice': None,
})