mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Don't resolve default tts engine in assist pipelines (#91943)
* Don't resolve default tts engine in assist pipelines * Set tts engine when creating default pipeline * Update tests
This commit is contained in:
parent
392a9f32c9
commit
1c3e1d2e13
@ -15,6 +15,7 @@ from homeassistant.components.tts.media_source import (
|
|||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.collection import (
|
from homeassistant.helpers.collection import (
|
||||||
CollectionError,
|
CollectionError,
|
||||||
ItemNotFound,
|
ItemNotFound,
|
||||||
@ -85,6 +86,8 @@ async def async_get_pipeline(
|
|||||||
# configured language
|
# configured language
|
||||||
stt_engine = stt.async_default_provider(hass)
|
stt_engine = stt.async_default_provider(hass)
|
||||||
stt_language = hass.config.language if stt_engine else None
|
stt_language = hass.config.language if stt_engine else None
|
||||||
|
tts_engine = tts.async_default_engine(hass)
|
||||||
|
tts_language = hass.config.language if tts_engine else None
|
||||||
return await pipeline_data.pipeline_store.async_create_item(
|
return await pipeline_data.pipeline_store.async_create_item(
|
||||||
{
|
{
|
||||||
"conversation_engine": None,
|
"conversation_engine": None,
|
||||||
@ -93,8 +96,8 @@ async def async_get_pipeline(
|
|||||||
"name": hass.config.language,
|
"name": hass.config.language,
|
||||||
"stt_engine": stt_engine,
|
"stt_engine": stt_engine,
|
||||||
"stt_language": stt_language,
|
"stt_language": stt_language,
|
||||||
"tts_engine": None,
|
"tts_engine": tts_engine,
|
||||||
"tts_language": None,
|
"tts_language": tts_language,
|
||||||
"tts_voice": None,
|
"tts_voice": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -420,14 +423,7 @@ class PipelineRun:
|
|||||||
|
|
||||||
async def prepare_text_to_speech(self) -> None:
|
async def prepare_text_to_speech(self) -> None:
|
||||||
"""Prepare text to speech."""
|
"""Prepare text to speech."""
|
||||||
engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)
|
engine = self.pipeline.tts_engine
|
||||||
|
|
||||||
if engine is None:
|
|
||||||
engine = self.pipeline.tts_engine or "default"
|
|
||||||
raise TextToSpeechError(
|
|
||||||
code="tts-not-supported",
|
|
||||||
message=f"Text to speech engine '{engine}' not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
tts_options = {}
|
tts_options = {}
|
||||||
if self.pipeline.tts_voice is not None:
|
if self.pipeline.tts_voice is not None:
|
||||||
@ -436,19 +432,26 @@ class PipelineRun:
|
|||||||
if self.tts_audio_output is not None:
|
if self.tts_audio_output is not None:
|
||||||
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
||||||
|
|
||||||
if not await tts.async_support_options(
|
try:
|
||||||
self.hass,
|
# pipeline.tts_engine can't be None or this function is not called
|
||||||
engine,
|
if not await tts.async_support_options(
|
||||||
self.pipeline.tts_language,
|
self.hass,
|
||||||
tts_options,
|
engine, # type: ignore[arg-type]
|
||||||
):
|
self.pipeline.tts_language,
|
||||||
|
tts_options,
|
||||||
|
):
|
||||||
|
raise TextToSpeechError(
|
||||||
|
code="tts-not-supported",
|
||||||
|
message=(
|
||||||
|
f"Text to speech engine {engine} "
|
||||||
|
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except HomeAssistantError as err:
|
||||||
raise TextToSpeechError(
|
raise TextToSpeechError(
|
||||||
code="tts-not-supported",
|
code="tts-not-supported",
|
||||||
message=(
|
message=f"Text to speech engine '{engine}' not found",
|
||||||
f"Text to speech engine {engine} "
|
) from err
|
||||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tts_engine = engine
|
self.tts_engine = engine
|
||||||
self.tts_options = tts_options
|
self.tts_options = tts_options
|
||||||
@ -596,6 +599,11 @@ class PipelineInput:
|
|||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"tts_input is required for text to speech"
|
"tts_input is required for text to speech"
|
||||||
)
|
)
|
||||||
|
if self.run.end_stage == PipelineStage.TTS:
|
||||||
|
if self.run.pipeline.tts_engine is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"the pipeline does not support text to speech"
|
||||||
|
)
|
||||||
|
|
||||||
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
|
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
|
||||||
|
|
||||||
|
@ -69,8 +69,8 @@ from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
|||||||
from .models import Voice
|
from .models import Voice
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"async_default_engine",
|
||||||
"async_get_media_source_audio",
|
"async_get_media_source_audio",
|
||||||
"async_resolve_engine",
|
|
||||||
"async_support_options",
|
"async_support_options",
|
||||||
"ATTR_AUDIO_OUTPUT",
|
"ATTR_AUDIO_OUTPUT",
|
||||||
"CONF_LANG",
|
"CONF_LANG",
|
||||||
@ -116,6 +116,26 @@ class TTSCache(TypedDict):
|
|||||||
pending: asyncio.Task | None
|
pending: asyncio.Task | None
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return the domain or entity id of the default engine.
|
||||||
|
|
||||||
|
Returns None if no engines found.
|
||||||
|
"""
|
||||||
|
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||||
|
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||||
|
|
||||||
|
if "cloud" in manager.providers:
|
||||||
|
return "cloud"
|
||||||
|
|
||||||
|
entity = next(iter(component.entities), None)
|
||||||
|
|
||||||
|
if entity is not None:
|
||||||
|
return entity.entity_id
|
||||||
|
|
||||||
|
return next(iter(manager.providers), None)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
||||||
"""Resolve engine.
|
"""Resolve engine.
|
||||||
@ -130,15 +150,7 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
|||||||
return None
|
return None
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
if "cloud" in manager.providers:
|
return async_default_engine(hass)
|
||||||
return "cloud"
|
|
||||||
|
|
||||||
entity = next(iter(component.entities), None)
|
|
||||||
|
|
||||||
if entity is not None:
|
|
||||||
return entity.entity_id
|
|
||||||
|
|
||||||
return next(iter(manager.providers), None)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_support_options(
|
async def async_support_options(
|
||||||
|
@ -64,7 +64,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': None,
|
||||||
}),
|
}),
|
||||||
@ -73,7 +73,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -61,7 +61,7 @@
|
|||||||
# name: test_audio_pipeline.5
|
# name: test_audio_pipeline.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': None,
|
||||||
})
|
})
|
||||||
@ -69,7 +69,7 @@
|
|||||||
# name: test_audio_pipeline.6
|
# name: test_audio_pipeline.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
@ -137,7 +137,7 @@
|
|||||||
# name: test_audio_pipeline_debug.5
|
# name: test_audio_pipeline_debug.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': None,
|
||||||
})
|
})
|
||||||
@ -145,7 +145,7 @@
|
|||||||
# name: test_audio_pipeline_debug.6
|
# name: test_audio_pipeline_debug.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
@ -295,7 +295,7 @@
|
|||||||
# name: test_tts_failed.1
|
# name: test_tts_failed.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
'tts_input': 'Lights are on.',
|
'tts_input': 'Lights are on.',
|
||||||
'voice': None,
|
'voice': None,
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user