diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index 13c37681259..dce521621c5 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import mimetypes from typing import TypedDict @@ -22,6 +23,8 @@ from homeassistant.exceptions import HomeAssistantError from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA from .helper import get_engine_instance +URL_QUERY_TTS_OPTIONS = "tts_options" + async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource: """Set up tts media source.""" @@ -55,8 +58,7 @@ def generate_media_source_id( params["cache"] = "true" if cache else "false" if language is not None: params["language"] = language - if options is not None: - params.update(options) + params[URL_QUERY_TTS_OPTIONS] = json.dumps(options, separators=(",", ":")) return ms_generate_media_source_id( DOMAIN, @@ -78,19 +80,28 @@ class MediaSourceOptions(TypedDict): def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions: """Turn a media source ID into options.""" parsed = URL(media_source_id) + if URL_QUERY_TTS_OPTIONS in parsed.query: + try: + options = json.loads(parsed.query[URL_QUERY_TTS_OPTIONS]) + except json.JSONDecodeError as err: + raise Unresolvable(f"Invalid TTS options: {err.msg}") from err + else: + options = { + k: v + for k, v in parsed.query.items() + if k not in ("message", "language", "cache") + } if "message" not in parsed.query: raise Unresolvable("No message specified.") - - options = dict(parsed.query) kwargs: MediaSourceOptions = { "engine": parsed.name, - "message": options.pop("message"), - "language": options.pop("language", None), + "message": parsed.query["message"], + "language": parsed.query.get("language"), "options": options, "cache": None, } - if "cache" in options: - kwargs["cache"] = options.pop("cache") == "true" + if "cache" in parsed.query: + kwargs["cache"] = parsed.query["cache"] == "true" return kwargs @@ -111,6 +122,8 @@ class TTSMediaSource(MediaSource): url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path( **media_source_id_to_kwargs(item.identifier) ) + except Unresolvable: + raise except HomeAssistantError as err: raise Unresolvable(str(err)) from err diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 7f29534e473..e14bbac1839 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -75,7 +75,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), @@ -164,7 +164,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', }), @@ -253,7 +253,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', }), @@ -366,7 +366,7 @@ dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 7ea6af7e0bd..131444c17ac 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -71,7 +71,7 @@ # name: test_audio_pipeline.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), @@ -152,7 +152,7 @@ # name: test_audio_pipeline_debug.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), @@ -245,7 +245,7 @@ # name: test_audio_pipeline_with_enhancements.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), @@ -348,7 +348,7 @@ # name: test_audio_pipeline_with_wake_word_no_timeout.8 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'mime_type': 'audio/mpeg', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index cf04fbb175b..2ab6dc16629 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1318,10 +1318,16 @@ async def test_tags_with_wave() -> None: @pytest.mark.parametrize( ("engine", "language", "options", "cache", "result_query"), [ - (None, None, None, None, ""), - (None, "de_DE", None, None, "language=de_DE"), - (None, "de_DE", {"voice": "henk"}, None, "language=de_DE&voice=henk"), - (None, "de_DE", None, True, "cache=true&language=de_DE"), + (None, None, None, None, "&tts_options=null"), + (None, "de_DE", None, None, "&language=de_DE&tts_options=null"), + ( + None, + "de_DE", + {"voice": "henk"}, + None, + "&language=de_DE&tts_options=%7B%22voice%22:%22henk%22%7D", + ), + (None, "de_DE", None, True, "&cache=true&language=de_DE&tts_options=null"), ], ) async def test_generate_media_source_id( @@ -1343,8 +1349,9 @@ async def test_generate_media_source_id( _, _, engine_query = media_source_id.rpartition("/") engine, _, query = engine_query.partition("?") assert engine == result_engine - assert query.startswith("message=msg") - assert query[12:] == result_query + query_prefix = "message=msg" + assert query.startswith(query_prefix) + assert query[len(query_prefix) :] == result_query @pytest.mark.parametrize( diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index 367b24dd4d0..d90923b02ab 100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -8,6 +8,11 @@ import pytest from homeassistant.components import media_source from homeassistant.components.media_player import BrowseError +from homeassistant.components.tts.media_source import ( + MediaSourceOptions, + generate_media_source_id, + media_source_id_to_kwargs, +) from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -93,14 +98,24 @@ async def test_browsing(hass: HomeAssistant, setup: str) -> None: await media_source.async_browse_media(hass, "media-source://tts/non-existing") -@pytest.mark.parametrize("mock_provider", [MSProvider(DEFAULT_LANG)]) +@pytest.mark.parametrize( + ("mock_provider", "extra_options"), + [ + (MSProvider(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"), + (MSProvider(DEFAULT_LANG), "&voice=Paulus"), + ], +) async def test_legacy_resolving( - hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_provider: MSProvider + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_provider: MSProvider, + extra_options: str, ) -> None: """Test resolving legacy provider.""" await mock_setup(hass, mock_provider) mock_get_tts_audio = mock_provider.get_tts_audio + mock_get_tts_audio.reset_mock() media_id = "media-source://tts/test?message=Hello%20World" media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") @@ -115,7 +130,9 @@ async def test_legacy_resolving( # Pass language and options mock_get_tts_audio.reset_mock() - media_id = "media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus" + media_id = ( + f"media-source://tts/test?message=Bye%20World&language=de_DE{extra_options}" + ) media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") assert media.mime_type == "audio/mpeg" @@ -128,14 +145,24 @@ async def test_legacy_resolving( assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"} -@pytest.mark.parametrize("mock_tts_entity", [MSEntity(DEFAULT_LANG)]) +@pytest.mark.parametrize( + ("mock_tts_entity", "extra_options"), + [ + (MSEntity(DEFAULT_LANG), "&tts_options=%7B%22voice%22%3A%22Paulus%22%7D"), + (MSEntity(DEFAULT_LANG), "&voice=Paulus"), + ], +) async def test_resolving( - hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_entity: MSEntity + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_tts_entity: MSEntity, + extra_options: str, ) -> None: """Test resolving entity.""" await mock_config_entry_setup(hass, mock_tts_entity) mock_get_tts_audio = mock_tts_entity.get_tts_audio + mock_get_tts_audio.reset_mock() media_id = "media-source://tts/tts.test?message=Hello%20World" media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") @@ -151,7 +178,7 @@ async def test_resolving( # Pass language and options mock_get_tts_audio.reset_mock() media_id = ( - "media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus" + f"media-source://tts/tts.test?message=Bye%20World&language=de_DE{extra_options}" ) media = await media_source.async_resolve_media(hass, media_id, None) assert media.url.startswith("/api/tts_proxy/") @@ -191,6 +218,17 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) -> hass, "media-source://tts/non-existing?message=bla", None ) + # Non-JSON tts options + with pytest.raises( + media_source.Unresolvable, + match="Invalid TTS options: Expecting property name enclosed in double quotes", + ): + await media_source.async_resolve_media( + hass, + f"media-source://tts/{engine}?message=bla&tts_options=%7Binvalid json", + None, + ) + # Non-existing option with pytest.raises( media_source.Unresolvable, @@ -198,6 +236,69 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) -> ): await media_source.async_resolve_media( hass, - f"media-source://tts/{engine}?message=bla&non_existing_option=bla", + f"media-source://tts/{engine}?message=bla&tts_options=%7B%22non_existing_option%22%3A%22bla%22%7D", None, ) + + +@pytest.mark.parametrize( + ("setup", "result_engine"), + [ + ("mock_setup", "test"), + ("mock_config_entry_setup", "tts.test"), + ], + indirect=["setup"], +) +async def test_generate_media_source_id_and_media_source_id_to_kwargs( + hass: HomeAssistant, + setup: str, + result_engine: str, +) -> None: + """Test media_source_id and media_source_id_to_kwargs.""" + kwargs: MediaSourceOptions = { + "engine": None, + "message": "hello", + "language": "en_US", + "options": {"age": 5}, + "cache": True, + } + media_source_id = generate_media_source_id(hass, **kwargs) + assert media_source_id_to_kwargs(media_source_id) == { + "engine": result_engine, + "message": "hello", + "language": "en_US", + "options": {"age": 5}, + "cache": True, + } + + kwargs = { + "engine": None, + "message": "hello", + "language": "en_US", + "options": {"age": [5, 6]}, + "cache": True, + } + media_source_id = generate_media_source_id(hass, **kwargs) + assert media_source_id_to_kwargs(media_source_id) == { + "engine": result_engine, + "message": "hello", + "language": "en_US", + "options": {"age": [5, 6]}, + "cache": True, + } + + kwargs = { + "engine": None, + "message": "hello", + "language": "en_US", + "options": {"age": {"k1": [5, 6], "k2": "v2"}}, + "cache": True, + } + media_source_id = generate_media_source_id(hass, **kwargs) + assert media_source_id_to_kwargs(media_source_id) == { + "engine": result_engine, + "message": "hello", + "language": "en_US", + "options": {"age": {"k1": [5, 6], "k2": "v2"}}, + "cache": True, + }