Refactor file_upload to avoid janus dep (#112032)

This commit is contained in:
J. Nick Koston 2024-03-02 10:58:08 -10:00 committed by GitHub
parent dca6104b4b
commit 546fc1e282
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 72 additions and 33 deletions

View File

@ -6,11 +6,11 @@ from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from queue import SimpleQueue
import shutil import shutil
import tempfile import tempfile
from aiohttp import BodyPartReader, web from aiohttp import BodyPartReader, web
import janus
import voluptuous as vol import voluptuous as vol
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
@ -131,16 +131,17 @@ class FileUploadView(HomeAssistantView):
reader = await request.multipart() reader = await request.multipart()
file_field_reader = await reader.next() file_field_reader = await reader.next()
filename: str | None
if ( if (
not isinstance(file_field_reader, BodyPartReader) not isinstance(file_field_reader, BodyPartReader)
or file_field_reader.name != "file" or file_field_reader.name != "file"
or file_field_reader.filename is None or (filename := file_field_reader.filename) is None
): ):
raise vol.Invalid("Expected a file") raise vol.Invalid("Expected a file")
try: try:
raise_if_invalid_filename(file_field_reader.filename) raise_if_invalid_filename(filename)
except ValueError as err: except ValueError as err:
raise web.HTTPBadRequest from err raise web.HTTPBadRequest from err
@ -152,39 +153,46 @@ class FileUploadView(HomeAssistantView):
file_upload_data: FileUploadData = hass.data[DOMAIN] file_upload_data: FileUploadData = hass.data[DOMAIN]
file_dir = file_upload_data.file_dir(file_id) file_dir = file_upload_data.file_dir(file_id)
queue: janus.Queue[bytes | None] = janus.Queue() queue: SimpleQueue[
tuple[bytes, asyncio.Future[None] | None] | None
] = SimpleQueue()
def _sync_queue_consumer( def _sync_queue_consumer() -> None:
sync_q: janus.SyncQueue[bytes | None], _file_name: str
) -> None:
file_dir.mkdir() file_dir.mkdir()
with (file_dir / _file_name).open("wb") as file_handle: with (file_dir / filename).open("wb") as file_handle:
while True: while True:
_chunk = sync_q.get() if (_chunk_future := queue.get()) is None:
if _chunk is None:
break break
_chunk, _future = _chunk_future
if _future is not None:
hass.loop.call_soon_threadsafe(_future.set_result, None)
file_handle.write(_chunk) file_handle.write(_chunk)
sync_q.task_done()
fut: asyncio.Future[None] | None = None fut: asyncio.Future[None] | None = None
try: try:
fut = hass.async_add_executor_job( fut = hass.async_add_executor_job(_sync_queue_consumer)
_sync_queue_consumer, megabytes_sending = 0
queue.sync_q,
file_field_reader.filename,
)
while chunk := await file_field_reader.read_chunk(ONE_MEGABYTE): while chunk := await file_field_reader.read_chunk(ONE_MEGABYTE):
queue.async_q.put_nowait(chunk) megabytes_sending += 1
if queue.async_q.qsize() > 5: # Allow up to 5 MB buffer size if megabytes_sending % 5 != 0:
await queue.async_q.join() queue.put_nowait((chunk, None))
queue.async_q.put_nowait(None) # terminate queue consumer continue
chunk_future = hass.loop.create_future()
queue.put_nowait((chunk, chunk_future))
await asyncio.wait(
(fut, chunk_future), return_when=asyncio.FIRST_COMPLETED
)
if fut.done():
# The executor job failed
break
queue.put_nowait(None) # terminate queue consumer
finally: finally:
if fut is not None: if fut is not None:
await fut await fut
file_upload_data.files[file_id] = file_field_reader.filename file_upload_data.files[file_id] = filename
return self.json({"file_id": file_id}) return self.json({"file_id": file_id})

View File

@ -5,6 +5,5 @@
"dependencies": ["http"], "dependencies": ["http"],
"documentation": "https://www.home-assistant.io/integrations/file_upload", "documentation": "https://www.home-assistant.io/integrations/file_upload",
"integration_type": "system", "integration_type": "system",
"quality_scale": "internal", "quality_scale": "internal"
"requirements": ["janus==1.0.0"]
} }

View File

@ -34,7 +34,6 @@ home-assistant-frontend==20240301.0
home-assistant-intents==2024.2.28 home-assistant-intents==2024.2.28
httpx==0.27.0 httpx==0.27.0
ifaddr==0.2.0 ifaddr==0.2.0
janus==1.0.0
Jinja2==3.1.3 Jinja2==3.1.3
lru-dict==1.3.0 lru-dict==1.3.0
mutagen==1.47.0 mutagen==1.47.0

View File

@ -1156,9 +1156,6 @@ iperf3==0.1.11
# homeassistant.components.gogogate2 # homeassistant.components.gogogate2
ismartgate==5.0.1 ismartgate==5.0.1
# homeassistant.components.file_upload
janus==1.0.0
# homeassistant.components.abode # homeassistant.components.abode
jaraco.abode==3.3.0 jaraco.abode==3.3.0

View File

@ -934,9 +934,6 @@ intellifire4py==2.2.2
# homeassistant.components.gogogate2 # homeassistant.components.gogogate2
ismartgate==5.0.1 ismartgate==5.0.1
# homeassistant.components.file_upload
janus==1.0.0
# homeassistant.components.abode # homeassistant.components.abode
jaraco.abode==3.3.0 jaraco.abode==3.3.0

View File

@ -1,4 +1,5 @@
"""Test the File Upload integration.""" """Test the File Upload integration."""
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from random import getrandbits from random import getrandbits
from unittest.mock import patch from unittest.mock import patch
@ -117,3 +118,41 @@ async def test_upload_with_wrong_key_fails(
res = await client.post("/api/file_upload", data={"wrong_key": large_file_io}) res = await client.post("/api/file_upload", data={"wrong_key": large_file_io})
assert res.status == 400 assert res.status == 400
async def test_upload_large_file_fails(
hass: HomeAssistant, hass_client: ClientSessionGenerator, large_file_io
) -> None:
"""Test uploading large file."""
assert await async_setup_component(hass, "file_upload", {})
client = await hass_client()
@contextmanager
def _mock_open(*args, **kwargs):
yield MockPathOpen()
class MockPathOpen:
def __init__(self, *args, **kwargs) -> None:
pass
def write(self, data: bytes) -> None:
raise OSError("Boom")
with patch(
# Patch temp dir name to avoid tests fail running in parallel
"homeassistant.components.file_upload.TEMP_DIR_NAME",
file_upload.TEMP_DIR_NAME + f"-{getrandbits(10):03x}",
), patch(
# Patch one megabyte to 8 bytes to prevent having to use big files in tests
"homeassistant.components.file_upload.ONE_MEGABYTE",
8,
), patch(
"homeassistant.components.file_upload.Path.open", return_value=_mock_open()
):
res = await client.post("/api/file_upload", data={"file": large_file_io})
assert res.status == 500
response = await res.content.read()
assert b"Boom" in response

View File

@ -548,7 +548,7 @@ async def test_discovery_requirements_mqtt(hass: HomeAssistant) -> None:
) as mock_process: ) as mock_process:
await async_get_integration_with_requirements(hass, "mqtt_comp") await async_get_integration_with_requirements(hass, "mqtt_comp")
assert len(mock_process.mock_calls) == 3 # mqtt also depends on http assert len(mock_process.mock_calls) == 2 # mqtt also depends on http
assert mock_process.mock_calls[0][1][1] == mqtt.requirements assert mock_process.mock_calls[0][1][1] == mqtt.requirements