From 984e30075bae1bb757eb1cf62bc4f87d0f37e0c2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 16 Mar 2022 02:18:55 -0700 Subject: [PATCH] Validate TTS base url (#68212) * Validate TTS base url * Update tests/components/tts/test_init.py Co-authored-by: Joakim Plate Co-authored-by: Joakim Plate --- homeassistant/components/tts/__init__.py | 14 ++++++++- tests/components/tts/test_init.py | 40 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 5e6629ca2a2..c001fb6b89b 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -17,6 +17,7 @@ from aiohttp import web import mutagen from mutagen.id3 import ID3, TextFrame as ID3Text import voluptuous as vol +import yarl from homeassistant.components.http import HomeAssistantView from homeassistant.components.media_player.const import ( @@ -41,6 +42,7 @@ from homeassistant.helpers.network import get_url from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_prepare_setup_platform +from homeassistant.util.network import normalize_url from homeassistant.util.yaml import load_yaml from .const import DOMAIN @@ -92,6 +94,16 @@ def _deprecated_platform(value): return value +def valid_base_url(value: str) -> str: + """Validate base url, return value.""" + url = yarl.URL(cv.url(value)) + + if url.path != "/": + raise vol.Invalid("Path should be empty") + + return normalize_url(value) + + PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( { vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform), @@ -100,7 +112,7 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All( vol.Coerce(int), vol.Range(min=60, max=57600) ), - vol.Optional(CONF_BASE_URL): cv.string, + vol.Optional(CONF_BASE_URL): valid_base_url, vol.Optional(CONF_SERVICE_NAME): cv.string, } ) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 9f1cc849a1f..5543e2d82f5 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -3,6 +3,7 @@ from http import HTTPStatus from unittest.mock import PropertyMock, patch import pytest +import voluptuous as vol import yarl from homeassistant.components import tts @@ -16,6 +17,7 @@ from homeassistant.components.media_player.const import ( ) from homeassistant.config import async_process_ha_core_config from homeassistant.setup import async_setup_component +from homeassistant.util.network import normalize_url from tests.common import assert_setup_component, async_mock_service @@ -689,3 +691,41 @@ async def test_tags_with_wave(hass, demo_provider): ) assert tagged_data != demo_data + + +@pytest.mark.parametrize( + "value", + ( + "http://example.local:8123", + "http://example.local", + "http://example.local:80", + "https://example.com", + "https://example.com:443", + "https://example.com:8123", + ), +) +def test_valid_base_url(value): + """Test we validate base urls.""" + assert tts.valid_base_url(value) == normalize_url(value) + # Test we strip trailing `/` + assert tts.valid_base_url(value + "/") == normalize_url(value) + + +@pytest.mark.parametrize( + "value", + ( + "http://example.local:8123/sub-path", + "http://example.local/sub-path", + "https://example.com/sub-path", + "https://example.com:8123/sub-path", + "mailto:some@email", + "http:example.com", + "http:/example.com", + "http//example.com", + "example.com", + ), +) +def test_invalid_base_url(value): + """Test we catch bad base urls.""" + with pytest.raises(vol.Invalid): + tts.valid_base_url(value)