Make tts options of type list (such as profiles in google_cloud) work (#121582)

* Allow tts options of type list such as profiles in google_cloud

* Update tests/components/tts/test_media_source.py

* Don't mix engine specific options with other options

* Fix test

* Update assist_pipeline snapshots

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
tronikos 2024-09-24 00:07:12 -07:00 committed by GitHub
parent 615ec548db
commit 4c0fb04f61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 150 additions and 29 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import json
import mimetypes
from typing import TypedDict
@ -22,6 +23,8 @@ from homeassistant.exceptions import HomeAssistantError
from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA
from .helper import get_engine_instance
URL_QUERY_TTS_OPTIONS = "tts_options"
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
"""Set up tts media source."""
@ -55,8 +58,7 @@ def generate_media_source_id(
params["cache"] = "true" if cache else "false"
if language is not None:
params["language"] = language
if options is not None:
params.update(options)
params[URL_QUERY_TTS_OPTIONS] = json.dumps(options, separators=(",", ":"))
return ms_generate_media_source_id(
DOMAIN,
@ -78,19 +80,28 @@ class MediaSourceOptions(TypedDict):
def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions:
"""Turn a media source ID into options."""
parsed = URL(media_source_id)
if URL_QUERY_TTS_OPTIONS in parsed.query:
try:
options = json.loads(parsed.query[URL_QUERY_TTS_OPTIONS])
except json.JSONDecodeError as err:
raise Unresolvable(f"Invalid TTS options: {err.msg}") from err
else:
options = {
k: v
for k, v in parsed.query.items()
if k not in ("message", "language", "cache")
}
if "message" not in parsed.query:
raise Unresolvable("No message specified.")
options = dict(parsed.query)
kwargs: MediaSourceOptions = {
"engine": parsed.name,
"message": options.pop("message"),
"language": options.pop("language", None),
"message": parsed.query["message"],
"language": parsed.query.get("language"),
"options": options,
"cache": None,
}
if "cache" in options:
kwargs["cache"] = options.pop("cache") == "true"
if "cache" in parsed.query:
kwargs["cache"] = parsed.query["cache"] == "true"
return kwargs
@ -111,6 +122,8 @@ class TTSMediaSource(MediaSource):
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path(
**media_source_id_to_kwargs(item.identifier)
)
except Unresolvable:
raise
except HomeAssistantError as err:
raise Unresolvable(str(err)) from err

View File

@ -75,7 +75,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
@ -164,7 +164,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
@ -253,7 +253,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
@ -366,7 +366,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),

View File

@ -71,7 +71,7 @@
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
@ -152,7 +152,7 @@
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
@ -245,7 +245,7 @@
# name: test_audio_pipeline_with_enhancements.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
@ -348,7 +348,7 @@
# name: test_audio_pipeline_with_wake_word_no_timeout.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),

View File

@ -1318,10 +1318,16 @@ async def test_tags_with_wave() -> None:
@pytest.mark.parametrize(
("engine", "language", "options", "cache", "result_query"),
[
(None, None, None, None, ""),
(None, "de_DE", None, None, "language=de_DE"),
(None, "de_DE", {"voice": "henk"}, None, "language=de_DE&voice=henk"),
(None, "de_DE", None, True, "cache=true&language=de_DE"),
(None, None, None, None, "&tts_options=null"),
(None, "de_DE", None, None, "&language=de_DE&tts_options=null"),
(
None,
"de_DE",
{"voice": "henk"},
None,
"&language=de_DE&tts_options=%7B%22voice%22:%22henk%22%7D",
),
(None, "de_DE", None, True, "&cache=true&language=de_DE&tts_options=null"),
],
)
async def test_generate_media_source_id(
@ -1343,8 +1349,9 @@ async def test_generate_media_source_id(
_, _, engine_query = media_source_id.rpartition("/")
engine, _, query = engine_query.partition("?")
assert engine == result_engine
assert query.startswith("message=msg")
assert query[12:] == result_query
query_prefix = "message=msg"
assert query.startswith(query_prefix)
assert query[len(query_prefix) :] == result_query
@pytest.mark.parametrize(

View File

@ -8,6 +8,11 @@ import pytest
from homeassistant.components import media_source
from homeassistant.components.media_player import BrowseError
from homeassistant.components.tts.media_source import (
MediaSourceOptions,
generate_media_source_id,
media_source_id_to_kwargs,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -93,14 +98,24 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None:
await media_source.async_browse_media(hass, "media-source://tts/non-existing")
@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)])
@pytest.mark.parametrize(
("mock_provider", "extra_options"),
[
(MSProvider(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"),
(MSProvider(DEFAULT_LANG), "&voice=Paulus"),
],
)
async def test_legacy_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_provider: MSProvider,
extra_options: str,
) -> None:
"""Test resolving legacy provider."""
await mock_setup(hass, mock_provider)
mock_get_tts_audio = mock_provider.get_tts_audio
mock_get_tts_audio.reset_mock()
media_id = "media-source://tts/test?message=Hello%20World"
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
@ -115,7 +130,9 @@ async def test_legacy_resolving(
# Pass language and options
mock_get_tts_audio.reset_mock()
media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus"
media_id = (
f"media-source://tts/test?message=Bye%20World&language=de_DE{extra_options}"
)
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
assert media.mime_type == "audio/mpeg"
@ -128,14 +145,24 @@ async def test_legacy_resolving(
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)])
@pytest.mark.parametrize(
("mock_tts_entity", "extra_options"),
[
(MSEntity(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"),
(MSEntity(DEFAULT_LANG), "&voice=Paulus"),
],
)
async def test_resolving(
hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_entity: MSEntity,
extra_options: str,
) -> None:
"""Test resolving entity."""
await mock_config_entry_setup(hass, mock_tts_entity)
mock_get_tts_audio = mock_tts_entity.get_tts_audio
mock_get_tts_audio.reset_mock()
media_id = "media-source://tts/tts.test?message=Hello%20World"
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
@ -151,7 +178,7 @@ async def test_resolving(
# Pass language and options
mock_get_tts_audio.reset_mock()
media_id = (
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus"
f"media-source://tts/tts.test?message=Bye%20World&language=de_DE{extra_options}"
)
media = await media_source.async_resolve_media(hass, media_id, None)
assert media.url.startswith("/api/tts_proxy/")
@ -191,6 +218,17 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) ->
hass, "media-source://tts/non-existing?message=bla", None
)
# Non-JSON tts options
with pytest.raises(
media_source.Unresolvable,
match="Invalid TTS options: Expecting property name enclosed in double quotes",
):
await media_source.async_resolve_media(
hass,
f"media-source://tts/{engine}?message=bla&tts_options=%7Binvalid json",
None,
)
# Non-existing option
with pytest.raises(
media_source.Unresolvable,
@ -198,6 +236,69 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) ->
):
await media_source.async_resolve_media(
hass,
f"media-source://tts/{engine}?message=bla&non_existing_option=bla",
f"media-source://tts/{engine}?message=bla&tts_options=%7B%22non_existing_option%22%3A%22bla%22%7D",
None,
)
@pytest.mark.parametrize(
("setup", "result_engine"),
[
("mock_setup", "test"),
("mock_config_entry_setup", "tts.test"),
],
indirect=["setup"],
)
async def test_generate_media_source_id_and_media_source_id_to_kwargs(
hass: HomeAssistant,
setup: str,
result_engine: str,
) -> None:
"""Test media_source_id and media_source_id_to_kwargs."""
kwargs: MediaSourceOptions = {
"engine": None,
"message": "hello",
"language": "en_US",
"options": {"age": 5},
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
"message": "hello",
"language": "en_US",
"options": {"age": 5},
"cache": True,
}
kwargs = {
"engine": None,
"message": "hello",
"language": "en_US",
"options": {"age": [5, 6]},
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
"message": "hello",
"language": "en_US",
"options": {"age": [5, 6]},
"cache": True,
}
kwargs = {
"engine": None,
"message": "hello",
"language": "en_US",
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
"cache": True,
}
media_source_id = generate_media_source_id(hass, **kwargs)
assert media_source_id_to_kwargs(media_source_id) == {
"engine": result_engine,
"message": "hello",
"language": "en_US",
"options": {"age": {"k1": [5, 6], "k2": "v2"}},
"cache": True,
}