diff --git a/homeassistant/components/file_upload/__init__.py b/homeassistant/components/file_upload/__init__.py
index 9f548e14459..73f8465b1df 100644
--- a/homeassistant/components/file_upload/__init__.py
+++ b/homeassistant/components/file_upload/__init__.py
@@ -9,7 +9,8 @@ from pathlib import Path
 import shutil
 import tempfile
 
-from aiohttp import web
+from aiohttp import BodyPartReader, web
+import janus
 import voluptuous as vol
 
 from homeassistant.components.http import HomeAssistantView
@@ -22,9 +23,8 @@ from homeassistant.util.ulid import ulid_hex
 
 DOMAIN = "file_upload"
 
-# If increased, change upload view to streaming
-# https://docs.aiohttp.org/en/stable/web_quickstart.html#file-uploads
-MAX_SIZE = 1024 * 1024 * 10
+ONE_MEGABYTE = 1024 * 1024
+MAX_SIZE = 100 * ONE_MEGABYTE
 TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"
 
 
@@ -126,14 +126,18 @@ class FileUploadView(HomeAssistantView):
         # Increase max payload
         request._client_max_size = MAX_SIZE  # pylint: disable=protected-access
 
-        data = await request.post()
-        file_field = data.get("file")
+        reader = await request.multipart()
+        file_field_reader = await reader.next()
 
-        if not isinstance(file_field, web.FileField):
+        if (
+            not isinstance(file_field_reader, BodyPartReader)
+            or file_field_reader.name != "file"
+            or file_field_reader.filename is None
+        ):
             raise vol.Invalid("Expected a file")
 
         try:
-            raise_if_invalid_filename(file_field.filename)
+            raise_if_invalid_filename(file_field_reader.filename)
         except ValueError as err:
             raise web.HTTPBadRequest from err
 
@@ -145,19 +149,41 @@ class FileUploadView(HomeAssistantView):
 
         file_upload_data: FileUploadData = hass.data[DOMAIN]
         file_dir = file_upload_data.file_dir(file_id)
+        queue: janus.Queue[bytes | None] = janus.Queue()
 
-        def _sync_work() -> None:
+        def _sync_queue_consumer(
+            sync_q: janus.SyncQueue[bytes | None], _file_name: str
+        ) -> None:
             file_dir.mkdir()
+            with (file_dir / _file_name).open("wb") as file_handle:
+                while True:
+                    _chunk = sync_q.get()
+                    if _chunk is None:
+                        break
 
-            # MyPy forgets about the isinstance check because we're in a function scope
-            assert isinstance(file_field, web.FileField)
+                    file_handle.write(_chunk)
+                    sync_q.task_done()
 
-            with (file_dir / file_field.filename).open("wb") as target_fileobj:
-                shutil.copyfileobj(file_field.file, target_fileobj)
+        fut: asyncio.Future[None] | None = None
+        try:
+            fut = hass.async_add_executor_job(
+                _sync_queue_consumer,
+                queue.sync_q,
+                file_field_reader.filename,
+            )
 
-        await hass.async_add_executor_job(_sync_work)
+            while chunk := await file_field_reader.read_chunk(ONE_MEGABYTE):
+                queue.async_q.put_nowait(chunk)
+                if queue.async_q.qsize() > 5:  # Allow up to 5 MB buffer size
+                    await queue.async_q.join()
+        finally:
+            # Always send the sentinel, even when reading the request fails;
+            # otherwise the consumer blocks on the queue and `fut` never completes.
+            queue.async_q.put_nowait(None)  # terminate queue consumer
+            if fut is not None:
+                await fut
 
-        file_upload_data.files[file_id] = file_field.filename
+        file_upload_data.files[file_id] = file_field_reader.filename
 
         return self.json({"file_id": file_id})
 
diff --git a/homeassistant/components/file_upload/manifest.json b/homeassistant/components/file_upload/manifest.json
index d2b4f88a279..62f7a1f2b27 100644
--- a/homeassistant/components/file_upload/manifest.json
+++ b/homeassistant/components/file_upload/manifest.json
@@ -2,6 +2,7 @@
   "domain": "file_upload",
   "name": "File Upload",
   "documentation": "https://www.home-assistant.io/integrations/file_upload",
+  "requirements": ["janus==1.0.0"],
   "dependencies": ["http"],
   "codeowners": ["@home-assistant/core"],
   "quality_scale": "internal",
diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt
index dbd749d7ca7..7958b776f10 100644
--- a/homeassistant/package_constraints.txt
+++ b/homeassistant/package_constraints.txt
@@ -25,6 +25,7 @@ home-assistant-bluetooth==1.8.1
 home-assistant-frontend==20221108.0
 httpx==0.23.1
 ifaddr==0.1.7
+janus==1.0.0
 jinja2==3.1.2
 lru-dict==1.1.8
 orjson==3.8.1
diff --git a/requirements_all.txt b/requirements_all.txt
index a74b4713fcd..ac8bbd9b144 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -964,6 +964,9 @@ iperf3==0.1.11
 # homeassistant.components.gogogate2
 ismartgate==4.0.4
 
+# homeassistant.components.file_upload
+janus==1.0.0
+
 # homeassistant.components.jellyfin
 jellyfin-apiclient-python==1.9.2
 
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index 98e255f30f8..c62ca320ee0 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -717,6 +717,9 @@ iotawattpy==0.1.0
 # homeassistant.components.gogogate2
 ismartgate==4.0.4
 
+# homeassistant.components.file_upload
+janus==1.0.0
+
 # homeassistant.components.jellyfin
 jellyfin-apiclient-python==1.9.2
 
diff --git a/tests/components/file_upload/conftest.py b/tests/components/file_upload/conftest.py
new file mode 100644
index 00000000000..ab9965c1914
--- /dev/null
+++ b/tests/components/file_upload/conftest.py
@@ -0,0 +1,13 @@
+"""Fixtures for FileUpload integration."""
+from io import StringIO
+
+import pytest
+
+
+@pytest.fixture
+def large_file_io() -> StringIO:
+    """Generate a file on the fly. Simulates a large file."""
+    return StringIO(
+        2
+        * "Home Assistant is awesome. Open source home automation that puts local control and privacy first."
+    )
diff --git a/tests/components/file_upload/test_init.py b/tests/components/file_upload/test_init.py
index ba3485c96e1..699fb6f9b84 100644
--- a/tests/components/file_upload/test_init.py
+++ b/tests/components/file_upload/test_init.py
@@ -64,3 +64,49 @@ async def test_removed_on_stop(hass: HomeAssistant, hass_client, uploaded_file_d
 
     # Test it's removed
     assert not uploaded_file_dir.exists()
+
+
+async def test_upload_large_file(hass: HomeAssistant, hass_client, large_file_io):
+    """Test uploading large file."""
+    assert await async_setup_component(hass, "file_upload", {})
+    client = await hass_client()
+
+    with patch(
+        # Patch temp dir name to avoid tests fail running in parallel
+        "homeassistant.components.file_upload.TEMP_DIR_NAME",
+        file_upload.TEMP_DIR_NAME + f"-{getrandbits(10):03x}",
+    ), patch(
+        # Patch one megabyte to 8 bytes to prevent having to use big files in tests
+        "homeassistant.components.file_upload.ONE_MEGABYTE",
+        8,
+    ):
+        res = await client.post("/api/file_upload", data={"file": large_file_io})
+
+        assert res.status == 200
+        response = await res.json()
+
+        file_dir = hass.data[file_upload.DOMAIN].file_dir(response["file_id"])
+        assert file_dir.is_dir()
+
+        large_file_io.seek(0)
+        with file_upload.process_uploaded_file(hass, file_dir.name) as file_path:
+            assert file_path.is_file()
+            assert file_path.parent == file_dir
+            assert file_path.read_bytes() == large_file_io.read().encode("utf-8")
+
+
+async def test_upload_with_wrong_key_fails(
+    hass: HomeAssistant, hass_client, large_file_io
+):
+    """Test uploading fails."""
+    assert await async_setup_component(hass, "file_upload", {})
+    client = await hass_client()
+
+    with patch(
+        # Patch temp dir name to avoid tests fail running in parallel
+        "homeassistant.components.file_upload.TEMP_DIR_NAME",
+        file_upload.TEMP_DIR_NAME + f"-{getrandbits(10):03x}",
+    ):
+        res = await client.post("/api/file_upload", data={"wrong_key": large_file_io})
+
+        assert res.status == 400
diff --git a/tests/test_requirements.py b/tests/test_requirements.py
index 826150e47c4..c10b49aa110 100644
--- a/tests/test_requirements.py
+++ b/tests/test_requirements.py
@@ -401,7 +401,7 @@ async def test_discovery_requirements_mqtt(hass):
     ) as mock_process:
         await async_get_integration_with_requirements(hass, "mqtt_comp")
 
-    assert len(mock_process.mock_calls) == 2  # mqtt also depends on http
+    assert len(mock_process.mock_calls) == 3  # mqtt also depends on http
     assert mock_process.mock_calls[0][1][1] == mqtt.requirements
 
 