From dd48f1e6fc2e170e65c00dfd20bcd28fedb2f4bd Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 10 Feb 2022 08:03:14 -0800 Subject: [PATCH] Allow uploading media to media folder (#66143) --- .../components/media_source/local_source.py | 101 ++++++++++++++- .../media_source/test_local_source.py | 122 ++++++++++++++++++ 2 files changed, 221 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/media_source/local_source.py b/homeassistant/components/media_source/local_source.py index 7213d6ac7a0..f43bb3d97c7 100644 --- a/homeassistant/components/media_source/local_source.py +++ b/homeassistant/components/media_source/local_source.py @@ -1,21 +1,29 @@ """Local Media Source Implementation.""" from __future__ import annotations +import logging import mimetypes from pathlib import Path from aiohttp import web +from aiohttp.web_request import FileField +from aioshutil import shutil +import voluptuous as vol from homeassistant.components.http import HomeAssistantView from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY from homeassistant.components.media_player.errors import BrowseError from homeassistant.core import HomeAssistant, callback -from homeassistant.util import raise_if_invalid_path +from homeassistant.exceptions import Unauthorized +from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES from .error import Unresolvable from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia +MAX_UPLOAD_SIZE = 1024 * 1024 * 10 +LOGGER = logging.getLogger(__name__) + @callback def async_setup(hass: HomeAssistant) -> None: @@ -23,6 +31,7 @@ def async_setup(hass: HomeAssistant) -> None: source = LocalSource(hass) hass.data[DOMAIN][DOMAIN] = source hass.http.register_view(LocalMediaView(hass, source)) + hass.http.register_view(UploadMediaView(hass, source)) class LocalSource(MediaSource): @@ -43,11 +52,14 @@ class LocalSource(MediaSource): @callback def async_parse_identifier(self, item: MediaSourceItem) -> tuple[str, str]: """Parse identifier.""" + if item.domain != DOMAIN: + raise Unresolvable("Unknown domain.") + if not item.identifier: # Empty source_dir_id and location return "", "" - source_dir_id, location = item.identifier.split("/", 1) + source_dir_id, _, location = item.identifier.partition("/") if source_dir_id not in self.hass.config.media_dirs: raise Unresolvable("Unknown source directory.") @@ -217,3 +229,88 @@ class LocalMediaView(HomeAssistantView): raise web.HTTPNotFound() return web.FileResponse(media_path) + + +class UploadMediaView(HomeAssistantView): + """View to upload images.""" + + url = "/api/media_source/local_source/upload" + name = "api:media_source:local_source:upload" + + def __init__(self, hass: HomeAssistant, source: LocalSource) -> None: + """Initialize the media view.""" + self.hass = hass + self.source = source + self.schema = vol.Schema( + { + "media_content_id": str, + "file": FileField, + } + ) + + async def post(self, request: web.Request) -> web.Response: + """Handle upload.""" + if not request["hass_user"].is_admin: + raise Unauthorized() + + # Increase max payload + request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access + + try: + data = self.schema(dict(await request.post())) + except vol.Invalid as err: + LOGGER.error("Received invalid upload data: %s", err) + raise web.HTTPBadRequest() from err + + try: + item = MediaSourceItem.from_uri(self.hass, data["media_content_id"]) + except ValueError as err: + LOGGER.error("Received invalid upload data: %s", err) + raise web.HTTPBadRequest() from err + + try: + source_dir_id, location = self.source.async_parse_identifier(item) + except Unresolvable as err: + LOGGER.error("Invalid local source ID") + raise web.HTTPBadRequest() from err + + uploaded_file: FileField = data["file"] + + if not uploaded_file.content_type.startswith(("image/", "video/")): + LOGGER.error("Content type not allowed") + raise vol.Invalid("Only images and video are allowed") + + try: + raise_if_invalid_filename(uploaded_file.filename) + except ValueError as err: + LOGGER.error("Invalid filename") + raise web.HTTPBadRequest() from err + + try: + await self.hass.async_add_executor_job( + self._move_file, + self.source.async_full_path(source_dir_id, location), + uploaded_file, + ) + except ValueError as err: + LOGGER.error("Moving upload failed: %s", err) + raise web.HTTPBadRequest() from err + + return self.json( + {"media_content_id": f"{data['media_content_id']}/{uploaded_file.filename}"} + ) + + def _move_file( # pylint: disable=no-self-use + self, target_dir: Path, uploaded_file: FileField + ) -> None: + """Move file to target.""" + if not target_dir.is_dir(): + raise ValueError("Target is not an existing directory") + + target_path = target_dir / uploaded_file.filename + + target_path.relative_to(target_dir) + raise_if_invalid_path(str(target_path)) + + with target_path.open("wb") as target_fp: + shutil.copyfileobj(uploaded_file.file, target_fp) diff --git a/tests/components/media_source/test_local_source.py b/tests/components/media_source/test_local_source.py index 8a9005d7a86..f9ee560620c 100644 --- a/tests/components/media_source/test_local_source.py +++ b/tests/components/media_source/test_local_source.py @@ -1,5 +1,9 @@ """Test Local Media Source.""" from http import HTTPStatus +import io +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch import pytest @@ -9,6 +13,20 @@ from homeassistant.config import async_process_ha_core_config from homeassistant.setup import async_setup_component +@pytest.fixture +async def temp_dir(hass): + """Return a temp dir.""" + with TemporaryDirectory() as tmpdirname: + target_dir = Path(tmpdirname) / "another_subdir" + target_dir.mkdir() + await async_process_ha_core_config( + hass, {"media_dirs": {"test_dir": str(target_dir)}} + ) + assert await async_setup_component(hass, const.DOMAIN, {}) + + yield str(target_dir) + + async def test_async_browse_media(hass): """Test browse media.""" local_media = hass.config.path("media") @@ -102,3 +120,107 @@ async def test_media_view(hass, hass_client): resp = await client.get("/media/recordings/test.mp3") assert resp.status == HTTPStatus.OK + + +async def test_upload_view(hass, hass_client, temp_dir, hass_admin_user): + """Allow uploading media.""" + + img = (Path(__file__).parent.parent / "image/logo.png").read_bytes() + + def get_file(name): + pic = io.BytesIO(img) + pic.name = name + return pic + + client = await hass_client() + + # Test normal upload + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": "media-source://media_source/test_dir/.", + "file": get_file("logo.png"), + }, + ) + + assert res.status == 200 + assert (Path(temp_dir) / "logo.png").is_file() + + # Test with bad media source ID + for bad_id in ( + # Subdir doesn't exist + "media-source://media_source/test_dir/some-other-dir", + # Main dir doesn't exist + "media-source://media_source/test_dir2", + # Location is invalid + "media-source://media_source/test_dir/..", + # Domain != media_source + "media-source://nest/test_dir/.", + # Completely something else + "http://bla", + ): + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": bad_id, + "file": get_file("bad-source-id.png"), + }, + ) + + assert res.status == 400 + assert not (Path(temp_dir) / "bad-source-id.png").is_file() + + # Test invalid POST data + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": "media-source://media_source/test_dir/.", + "file": get_file("invalid-data.png"), + "incorrect": "format", + }, + ) + + assert res.status == 400 + assert not (Path(temp_dir) / "invalid-data.png").is_file() + + # Test invalid content type + text_file = io.BytesIO(b"Hello world") + text_file.name = "hello.txt" + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": "media-source://media_source/test_dir/.", + "file": text_file, + }, + ) + + assert res.status == 400 + assert not (Path(temp_dir) / "hello.txt").is_file() + + # Test invalid filename + with patch( + "aiohttp.formdata.guess_filename", return_value="../invalid-filename.png" + ): + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": "media-source://media_source/test_dir/.", + "file": get_file("../invalid-filename.png"), + }, + ) + + assert res.status == 400 + assert not (Path(temp_dir) / "../invalid-filename.png").is_file() + + # Remove admin access + hass_admin_user.groups = [] + res = await client.post( + "/api/media_source/local_source/upload", + data={ + "media_content_id": "media-source://media_source/test_dir/.", + "file": get_file("no-admin-test.png"), + }, + ) + + assert res.status == 401 + assert not (Path(temp_dir) / "no-admin-test.png").is_file()