From 1b0aa30881312cbc875c2556a4c3eb484a4575fd Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Mon, 27 Jan 2025 10:01:29 +0100 Subject: [PATCH] Extend backup upload API with file name parameter (#5568) * Extend backup upload API with file name parameter Add a query parameter which allows to specify the file name on upload. All locations will store the backup with the same file name. * ruff format * Update tests to cover bad filename * Fix ruff check error * Drop unnecessary logging --- supervisor/api/backups.py | 14 ++++++++++++- supervisor/backups/manager.py | 11 ++++++++--- tests/api/test_backups.py | 37 +++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/supervisor/api/backups.py b/supervisor/api/backups.py index f8c0e7bb3..573499055 100644 --- a/supervisor/api/backups.py +++ b/supervisor/api/backups.py @@ -14,6 +14,7 @@ from typing import Any from aiohttp import web from aiohttp.hdrs import CONTENT_DISPOSITION import voluptuous as vol +from voluptuous.humanize import humanize_error from ..backups.backup import Backup from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE @@ -503,6 +504,14 @@ class APIBackups(CoreSysAttributes): if location and location != LOCATION_CLOUD_BACKUP: tmp_path = location.local_where + filename: str | None = None + if ATTR_FILENAME in request.query: + filename = request.query.get(ATTR_FILENAME) + try: + vol.Match(RE_BACKUP_FILENAME)(filename) + except vol.Invalid as ex: + raise APIError(humanize_error(filename, ex)) from None + with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir: tar_file = Path(temp_dir, "backup.tar") reader = await request.multipart() @@ -529,7 +538,10 @@ class APIBackups(CoreSysAttributes): backup = await asyncio.shield( self.sys_backups.import_backup( - tar_file, location=location, additional_locations=locations + tar_file, + filename, + location=location, + additional_locations=locations, ) ) diff --git a/supervisor/backups/manager.py b/supervisor/backups/manager.py index 0f26cfff5..76d47780e 100644 --- a/supervisor/backups/manager.py +++ b/supervisor/backups/manager.py @@ -365,6 +365,7 @@ class BackupManager(FileConfiguration, JobGroup): async def import_backup( self, tar_file: Path, + filename: str | None = None, location: LOCATION_TYPE = None, additional_locations: list[LOCATION_TYPE] | None = None, ) -> Backup | None: @@ -376,9 +377,13 @@ class BackupManager(FileConfiguration, JobGroup): return None # Move backup to destination folder - tar_origin = Path(self._get_base_path(location), f"{backup.slug}.tar") + if filename: + tar_file = Path(self._get_base_path(location), Path(filename).name) + else: + tar_file = Path(self._get_base_path(location), f"{backup.slug}.tar") + try: - backup.tarfile.rename(tar_origin) + backup.tarfile.rename(tar_file) except OSError as err: if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}: @@ -387,7 +392,7 @@ class BackupManager(FileConfiguration, JobGroup): return None # Load new backup - backup = Backup(self.coresys, tar_origin, backup.slug, location, backup.data) + backup = Backup(self.coresys, tar_file, backup.slug, location, backup.data) if not await backup.load(): # Remove invalid backup from location it was moved to backup.tarfile.unlink() diff --git a/tests/api/test_backups.py b/tests/api/test_backups.py index 6003bd339..57e32c485 100644 --- a/tests/api/test_backups.py +++ b/tests/api/test_backups.py @@ -718,6 +718,43 @@ async def test_upload_duplicate_backup_new_location( assert coresys.backups.get("7fed74c8").location is None +@pytest.mark.parametrize( + ("filename", "expected_status"), + [("good.tar", 200), ("../bad.tar", 400), ("bad", 400)], +) +@pytest.mark.usefixtures("tmp_supervisor_data") +async def test_upload_with_filename( + api_client: TestClient, coresys: CoreSys, filename: str, expected_status: int +): + """Test uploading a backup to multiple locations.""" + backup_file = get_fixture_path("backup_example.tar") + + with backup_file.open("rb") as file, MultipartWriter("form-data") as mp: + mp.append(file) + resp = await api_client.post( + f"/backups/new/upload?filename={filename}", data=mp + ) + + assert resp.status == expected_status + body = await resp.json() + if expected_status != 200: + assert ( + body["message"] + == r"does not match regular expression ^[^\\\/]+\.tar$." + + f" Got '{filename}'" + ) + return + + assert body["data"]["slug"] == "7fed74c8" + + orig_backup = coresys.config.path_backup / filename + assert orig_backup.exists() + assert coresys.backups.get("7fed74c8").all_locations == { + None: {"path": orig_backup, "protected": False} + } + assert coresys.backups.get("7fed74c8").location is None + + @pytest.mark.parametrize( ("method", "url"), [