Validate TTS base url (#68212)

* Validate TTS base url

* Update tests/components/tts/test_init.py

Co-authored-by: Joakim Plate <elupus@ecce.se>

Co-authored-by: Joakim Plate <elupus@ecce.se>
This commit is contained in:
Paulus Schoutsen 2022-03-16 02:18:55 -07:00 committed by GitHub
parent 21aa07e3e5
commit 984e30075b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 1 deletions

View File

@ -17,6 +17,7 @@ from aiohttp import web
import mutagen import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text from mutagen.id3 import ID3, TextFrame as ID3Text
import voluptuous as vol import voluptuous as vol
import yarl
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player.const import ( from homeassistant.components.media_player.const import (
@ -41,6 +42,7 @@ from homeassistant.helpers.network import get_url
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_prepare_setup_platform from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.network import normalize_url
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from .const import DOMAIN from .const import DOMAIN
@ -92,6 +94,16 @@ def _deprecated_platform(value):
return value return value
def valid_base_url(value: str) -> str:
"""Validate base url, return value."""
url = yarl.URL(cv.url(value))
if url.path != "/":
raise vol.Invalid("Path should be empty")
return normalize_url(value)
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
{ {
vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform), vol.Required(CONF_PLATFORM): vol.All(cv.string, _deprecated_platform),
@ -100,7 +112,7 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All( vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): vol.All(
vol.Coerce(int), vol.Range(min=60, max=57600) vol.Coerce(int), vol.Range(min=60, max=57600)
), ),
vol.Optional(CONF_BASE_URL): cv.string, vol.Optional(CONF_BASE_URL): valid_base_url,
vol.Optional(CONF_SERVICE_NAME): cv.string, vol.Optional(CONF_SERVICE_NAME): cv.string,
} }
) )

View File

@ -3,6 +3,7 @@ from http import HTTPStatus
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
import pytest import pytest
import voluptuous as vol
import yarl import yarl
from homeassistant.components import tts from homeassistant.components import tts
@ -16,6 +17,7 @@ from homeassistant.components.media_player.const import (
) )
from homeassistant.config import async_process_ha_core_config from homeassistant.config import async_process_ha_core_config
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.network import normalize_url
from tests.common import assert_setup_component, async_mock_service from tests.common import assert_setup_component, async_mock_service
@ -689,3 +691,41 @@ async def test_tags_with_wave(hass, demo_provider):
) )
assert tagged_data != demo_data assert tagged_data != demo_data
@pytest.mark.parametrize(
"value",
(
"http://example.local:8123",
"http://example.local",
"http://example.local:80",
"https://example.com",
"https://example.com:443",
"https://example.com:8123",
),
)
def test_valid_base_url(value):
"""Test we validate base urls."""
assert tts.valid_base_url(value) == normalize_url(value)
# Test we strip trailing `/`
assert tts.valid_base_url(value + "/") == normalize_url(value)
@pytest.mark.parametrize(
"value",
(
"http://example.local:8123/sub-path",
"http://example.local/sub-path",
"https://example.com/sub-path",
"https://example.com:8123/sub-path",
"mailto:some@email",
"http:example.com",
"http:/example.com",
"http//example.com",
"example.com",
),
)
def test_invalid_base_url(value):
"""Test we catch bad base urls."""
with pytest.raises(vol.Invalid):
tts.valid_base_url(value)