mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 19:57:07 +00:00
Don't resolve default tts engine in assist pipelines (#91943)
* Don't resolve default tts engine in assist pipelines * Set tts engine when creating default pipeline * Update tests
This commit is contained in:
parent
392a9f32c9
commit
1c3e1d2e13
@ -15,6 +15,7 @@ from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.collection import (
|
||||
CollectionError,
|
||||
ItemNotFound,
|
||||
@ -85,6 +86,8 @@ async def async_get_pipeline(
|
||||
# configured language
|
||||
stt_engine = stt.async_default_provider(hass)
|
||||
stt_language = hass.config.language if stt_engine else None
|
||||
tts_engine = tts.async_default_engine(hass)
|
||||
tts_language = hass.config.language if tts_engine else None
|
||||
return await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"conversation_engine": None,
|
||||
@ -93,8 +96,8 @@ async def async_get_pipeline(
|
||||
"name": hass.config.language,
|
||||
"stt_engine": stt_engine,
|
||||
"stt_language": stt_language,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_engine": tts_engine,
|
||||
"tts_language": tts_language,
|
||||
"tts_voice": None,
|
||||
}
|
||||
)
|
||||
@ -420,14 +423,7 @@ class PipelineRun:
|
||||
|
||||
async def prepare_text_to_speech(self) -> None:
|
||||
"""Prepare text to speech."""
|
||||
engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)
|
||||
|
||||
if engine is None:
|
||||
engine = self.pipeline.tts_engine or "default"
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=f"Text to speech engine '{engine}' not found",
|
||||
)
|
||||
engine = self.pipeline.tts_engine
|
||||
|
||||
tts_options = {}
|
||||
if self.pipeline.tts_voice is not None:
|
||||
@ -436,19 +432,26 @@ class PipelineRun:
|
||||
if self.tts_audio_output is not None:
|
||||
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
||||
|
||||
if not await tts.async_support_options(
|
||||
self.hass,
|
||||
engine,
|
||||
self.pipeline.tts_language,
|
||||
tts_options,
|
||||
):
|
||||
try:
|
||||
# pipeline.tts_engine can't be None or this function is not called
|
||||
if not await tts.async_support_options(
|
||||
self.hass,
|
||||
engine, # type: ignore[arg-type]
|
||||
self.pipeline.tts_language,
|
||||
tts_options,
|
||||
):
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||
),
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||
),
|
||||
)
|
||||
message=f"Text to speech engine '{engine}' not found",
|
||||
) from err
|
||||
|
||||
self.tts_engine = engine
|
||||
self.tts_options = tts_options
|
||||
@ -596,6 +599,11 @@ class PipelineInput:
|
||||
raise PipelineRunValidationError(
|
||||
"tts_input is required for text to speech"
|
||||
)
|
||||
if self.run.end_stage == PipelineStage.TTS:
|
||||
if self.run.pipeline.tts_engine is None:
|
||||
raise PipelineRunValidationError(
|
||||
"the pipeline does not support text to speech"
|
||||
)
|
||||
|
||||
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
|
||||
|
||||
|
@ -69,8 +69,8 @@ from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||
from .models import Voice
|
||||
|
||||
__all__ = [
|
||||
"async_default_engine",
|
||||
"async_get_media_source_audio",
|
||||
"async_resolve_engine",
|
||||
"async_support_options",
|
||||
"ATTR_AUDIO_OUTPUT",
|
||||
"CONF_LANG",
|
||||
@ -116,6 +116,26 @@ class TTSCache(TypedDict):
|
||||
pending: asyncio.Task | None
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine.
|
||||
|
||||
Returns None if no engines found.
|
||||
"""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if "cloud" in manager.providers:
|
||||
return "cloud"
|
||||
|
||||
entity = next(iter(component.entities), None)
|
||||
|
||||
if entity is not None:
|
||||
return entity.entity_id
|
||||
|
||||
return next(iter(manager.providers), None)
|
||||
|
||||
|
||||
@callback
|
||||
def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
||||
"""Resolve engine.
|
||||
@ -130,15 +150,7 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
||||
return None
|
||||
return engine
|
||||
|
||||
if "cloud" in manager.providers:
|
||||
return "cloud"
|
||||
|
||||
entity = next(iter(component.entities), None)
|
||||
|
||||
if entity is not None:
|
||||
return entity.entity_id
|
||||
|
||||
return next(iter(manager.providers), None)
|
||||
return async_default_engine(hass)
|
||||
|
||||
|
||||
async def async_support_options(
|
||||
|
@ -64,7 +64,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
}),
|
||||
@ -73,7 +73,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
|
@ -61,7 +61,7 @@
|
||||
# name: test_audio_pipeline.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
})
|
||||
@ -69,7 +69,7 @@
|
||||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
@ -137,7 +137,7 @@
|
||||
# name: test_audio_pipeline_debug.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
})
|
||||
@ -145,7 +145,7 @@
|
||||
# name: test_audio_pipeline_debug.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
@ -295,7 +295,7 @@
|
||||
# name: test_tts_failed.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
'tts_input': 'Lights are on.',
|
||||
'voice': None,
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user