Filter preferred TTS format options if not supported (#114392)

Filter preferred format options if not supported
This commit is contained in:
Michael Hansen 2024-03-28 11:09:15 -05:00 committed by GitHub
parent 3df03f5be5
commit 6fafb9c9b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 150 additions and 32 deletions

View File

@ -16,7 +16,7 @@ import os
import re import re
import subprocess import subprocess
import tempfile import tempfile
from typing import Any, TypedDict, final from typing import Any, Final, TypedDict, final
from aiohttp import web from aiohttp import web
import mutagen import mutagen
@ -99,6 +99,13 @@ ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id" ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice" ATTR_VOICE = "voice"
_DEFAULT_FORMAT = "mp3"
_PREFFERED_FORMAT_OPTIONS: Final[set[str]] = {
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
}
CONF_LANG = "language" CONF_LANG = "language"
SERVICE_CLEAR_CACHE = "clear_cache" SERVICE_CLEAR_CACHE = "clear_cache"
@ -569,25 +576,23 @@ class SpeechManager:
): ):
raise HomeAssistantError(f"Language '{language}' not supported") raise HomeAssistantError(f"Language '{language}' not supported")
options = options or {}
supported_options = engine_instance.supported_options or []
# Update default options with provided options # Update default options with provided options
invalid_opts: list[str] = []
merged_options = dict(engine_instance.default_options or {}) merged_options = dict(engine_instance.default_options or {})
merged_options.update(options or {}) for option_name, option_value in options.items():
# Only count an option as invalid if it's not a "preferred format"
# option. These are used as hints to the TTS system if supported,
# and otherwise as parameters to ffmpeg conversion.
if (option_name in supported_options) or (
option_name in _PREFFERED_FORMAT_OPTIONS
):
merged_options[option_name] = option_value
else:
invalid_opts.append(option_name)
supported_options = list(engine_instance.supported_options or [])
# ATTR_PREFERRED_* options are always "supported" since they're used to
# convert audio after the TTS has run (if necessary).
supported_options.extend(
(
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
)
)
invalid_opts = [
opt_name for opt_name in merged_options if opt_name not in supported_options
]
if invalid_opts: if invalid_opts:
raise HomeAssistantError(f"Invalid options found: {invalid_opts}") raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
@ -687,10 +692,31 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
options = options or {} options = dict(options or {})
supported_options = engine_instance.supported_options or []
# Default to MP3 unless a different format is preferred # Extract preferred format options.
final_extension = options.get(ATTR_PREFERRED_FORMAT, "mp3") #
# These options are used by Assist pipelines, etc. to get a format that
# the voice satellite will support.
#
# The TTS system ideally supports options directly so we won't have
# to convert with ffmpeg later. If not, we pop the options here and
# perform the conversion after receiving the audio.
if ATTR_PREFERRED_FORMAT in supported_options:
final_extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
else:
final_extension = options.pop(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
if ATTR_PREFERRED_SAMPLE_RATE in supported_options:
sample_rate = options.get(ATTR_PREFERRED_SAMPLE_RATE)
else:
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)
if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
else:
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)
async def get_tts_data() -> str: async def get_tts_data() -> str:
"""Handle data available.""" """Handle data available."""
@ -716,8 +742,8 @@ class SpeechManager:
# rate/format/channel count is requested. # rate/format/channel count is requested.
needs_conversion = ( needs_conversion = (
(final_extension != extension) (final_extension != extension)
or (ATTR_PREFERRED_SAMPLE_RATE in options) or (sample_rate is not None)
or (ATTR_PREFERRED_SAMPLE_CHANNELS in options) or (sample_channels is not None)
) )
if needs_conversion: if needs_conversion:
@ -726,8 +752,8 @@ class SpeechManager:
extension, extension,
data, data,
to_extension=final_extension, to_extension=final_extension,
to_sample_rate=options.get(ATTR_PREFERRED_SAMPLE_RATE), to_sample_rate=sample_rate,
to_sample_channels=options.get(ATTR_PREFERRED_SAMPLE_CHANNELS), to_sample_channels=sample_channels,
) )
# Create file infos # Create file infos

View File

@ -111,6 +111,7 @@ class MockTTSProvider(tts.Provider):
tts.Voice("fran_drescher", "Fran Drescher"), tts.Voice("fran_drescher", "Fran Drescher"),
] ]
} }
_supported_options = ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
@property @property
def default_language(self) -> str: def default_language(self) -> str:
@ -130,7 +131,7 @@ class MockTTSProvider(tts.Provider):
@property @property
def supported_options(self) -> list[str]: def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotions.""" """Return list of supported options like voice, emotions."""
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT] return self._supported_options
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, message: str, language: str, options: dict[str, Any]

View File

@ -11,7 +11,7 @@ import wave
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt, tts from homeassistant.components import assist_pipeline, media_source, stt, tts
from homeassistant.components.assist_pipeline.const import ( from homeassistant.components.assist_pipeline.const import (
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DOMAIN, DOMAIN,
@ -19,9 +19,14 @@ from homeassistant.components.assist_pipeline.const import (
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .conftest import MockSttProvider, MockSttProviderEntity, MockWakeWordEntity from .conftest import (
MockSttProvider,
MockSttProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
)
from tests.typing import WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
BYTES_ONE_SECOND = 16000 * 2 BYTES_ONE_SECOND = 16000 * 2
@ -729,15 +734,17 @@ def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
async def test_tts_audio_output( async def test_tts_audio_output(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components, init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData, pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test using tts_audio_output with wav sets options correctly.""" """Test using tts_audio_output with wav sets options correctly."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})
def event_callback(event): events: list[assist_pipeline.PipelineEvent] = []
pass
pipeline_store = pipeline_data.pipeline_store pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item() pipeline_id = pipeline_store.async_get_preferred_item()
@ -753,7 +760,7 @@ async def test_tts_audio_output(
pipeline=pipeline, pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS, start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS, end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback, event_callback=events.append,
tts_audio_output="wav", tts_audio_output="wav",
), ),
) )
@ -764,3 +771,87 @@ async def test_tts_audio_output(
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000 assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1 assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
await pipeline_input.execute()
for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
# Ensure that no unsupported options were passed in
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
extra_options = set(options).difference(mock_tts_provider.supported_options)
assert len(extra_options) == 0, extra_options
async def test_tts_supports_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})
events: list[assist_pipeline.PipelineEvent] = []
pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=events.append,
tts_audio_output="wav",
),
)
await pipeline_input.validate()
# Make the TTS provider support preferred format options
supported_options = list(mock_tts_provider.supported_options or [])
supported_options.extend(
[
tts.ATTR_PREFERRED_FORMAT,
tts.ATTR_PREFERRED_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
]
)
with (
patch.object(mock_tts_provider, "_supported_options", supported_options),
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
):
await pipeline_input.execute()
for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)
assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
# We should have received preferred format options in get_tts_audio
assert tts.ATTR_PREFERRED_FORMAT in options
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options