mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-04-19 10:47:15 +00:00
Extend backup upload API with file name parameter (#5568)
* Extend backup upload API with file name parameter Add a query parameter which allows to specify the file name on upload. All locations will store the backup with the same file name. * ruff format * Update tests to cover bad filename * Fix ruff check error * Drop unnecessary logging
This commit is contained in:
parent
2a8d2d2b48
commit
1b0aa30881
@ -14,6 +14,7 @@ from typing import Any
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import CONTENT_DISPOSITION
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from ..backups.backup import Backup
|
||||
from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE
|
||||
@ -503,6 +504,14 @@ class APIBackups(CoreSysAttributes):
|
||||
if location and location != LOCATION_CLOUD_BACKUP:
|
||||
tmp_path = location.local_where
|
||||
|
||||
filename: str | None = None
|
||||
if ATTR_FILENAME in request.query:
|
||||
filename = request.query.get(ATTR_FILENAME)
|
||||
try:
|
||||
vol.Match(RE_BACKUP_FILENAME)(filename)
|
||||
except vol.Invalid as ex:
|
||||
raise APIError(humanize_error(filename, ex)) from None
|
||||
|
||||
with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir:
|
||||
tar_file = Path(temp_dir, "backup.tar")
|
||||
reader = await request.multipart()
|
||||
@ -529,7 +538,10 @@ class APIBackups(CoreSysAttributes):
|
||||
|
||||
backup = await asyncio.shield(
|
||||
self.sys_backups.import_backup(
|
||||
tar_file, location=location, additional_locations=locations
|
||||
tar_file,
|
||||
filename,
|
||||
location=location,
|
||||
additional_locations=locations,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -365,6 +365,7 @@ class BackupManager(FileConfiguration, JobGroup):
|
||||
async def import_backup(
|
||||
self,
|
||||
tar_file: Path,
|
||||
filename: str | None = None,
|
||||
location: LOCATION_TYPE = None,
|
||||
additional_locations: list[LOCATION_TYPE] | None = None,
|
||||
) -> Backup | None:
|
||||
@ -376,9 +377,13 @@ class BackupManager(FileConfiguration, JobGroup):
|
||||
return None
|
||||
|
||||
# Move backup to destination folder
|
||||
tar_origin = Path(self._get_base_path(location), f"{backup.slug}.tar")
|
||||
if filename:
|
||||
tar_file = Path(self._get_base_path(location), Path(filename).name)
|
||||
else:
|
||||
tar_file = Path(self._get_base_path(location), f"{backup.slug}.tar")
|
||||
|
||||
try:
|
||||
backup.tarfile.rename(tar_origin)
|
||||
backup.tarfile.rename(tar_file)
|
||||
|
||||
except OSError as err:
|
||||
if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}:
|
||||
@ -387,7 +392,7 @@ class BackupManager(FileConfiguration, JobGroup):
|
||||
return None
|
||||
|
||||
# Load new backup
|
||||
backup = Backup(self.coresys, tar_origin, backup.slug, location, backup.data)
|
||||
backup = Backup(self.coresys, tar_file, backup.slug, location, backup.data)
|
||||
if not await backup.load():
|
||||
# Remove invalid backup from location it was moved to
|
||||
backup.tarfile.unlink()
|
||||
|
@ -718,6 +718,43 @@ async def test_upload_duplicate_backup_new_location(
|
||||
assert coresys.backups.get("7fed74c8").location is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "expected_status"),
|
||||
[("good.tar", 200), ("../bad.tar", 400), ("bad", 400)],
|
||||
)
|
||||
@pytest.mark.usefixtures("tmp_supervisor_data")
|
||||
async def test_upload_with_filename(
|
||||
api_client: TestClient, coresys: CoreSys, filename: str, expected_status: int
|
||||
):
|
||||
"""Test uploading a backup to multiple locations."""
|
||||
backup_file = get_fixture_path("backup_example.tar")
|
||||
|
||||
with backup_file.open("rb") as file, MultipartWriter("form-data") as mp:
|
||||
mp.append(file)
|
||||
resp = await api_client.post(
|
||||
f"/backups/new/upload?filename={filename}", data=mp
|
||||
)
|
||||
|
||||
assert resp.status == expected_status
|
||||
body = await resp.json()
|
||||
if expected_status != 200:
|
||||
assert (
|
||||
body["message"]
|
||||
== r"does not match regular expression ^[^\\\/]+\.tar$."
|
||||
+ f" Got '{filename}'"
|
||||
)
|
||||
return
|
||||
|
||||
assert body["data"]["slug"] == "7fed74c8"
|
||||
|
||||
orig_backup = coresys.config.path_backup / filename
|
||||
assert orig_backup.exists()
|
||||
assert coresys.backups.get("7fed74c8").all_locations == {
|
||||
None: {"path": orig_backup, "protected": False}
|
||||
}
|
||||
assert coresys.backups.get("7fed74c8").location is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "url"),
|
||||
[
|
||||
|
Loading…
x
Reference in New Issue
Block a user