mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 10:59:40 +00:00
Don't encrypt or decrypt unknown files in backup archives (#144495)
This commit is contained in:
@@ -295,13 +295,26 @@ def validate_password_stream(
|
||||
raise BackupEmpty
|
||||
|
||||
|
||||
def _get_expected_archives(backup: AgentBackup) -> set[str]:
|
||||
"""Get the expected archives in the backup."""
|
||||
expected_archives = set()
|
||||
if backup.homeassistant_included:
|
||||
expected_archives.add("homeassistant")
|
||||
for addon in backup.addons:
|
||||
expected_archives.add(addon.slug)
|
||||
for folder in backup.folders:
|
||||
expected_archives.add(folder.value)
|
||||
return expected_archives
|
||||
|
||||
|
||||
def decrypt_backup(
|
||||
backup: AgentBackup,
|
||||
input_stream: IO[bytes],
|
||||
output_stream: IO[bytes],
|
||||
password: str | None,
|
||||
on_done: Callable[[Exception | None], None],
|
||||
minimum_size: int,
|
||||
nonces: list[bytes],
|
||||
nonces: NonceGenerator,
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
error: Exception | None = None
|
||||
@@ -315,7 +328,7 @@ def decrypt_backup(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_decrypt_backup(input_tar, output_tar, password)
|
||||
_decrypt_backup(backup, input_tar, output_tar, password)
|
||||
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error decrypting backup: %s", err)
|
||||
error = err
|
||||
@@ -333,15 +346,18 @@ def decrypt_backup(
|
||||
|
||||
|
||||
def _decrypt_backup(
|
||||
backup: AgentBackup,
|
||||
input_tar: tarfile.TarFile,
|
||||
output_tar: tarfile.TarFile,
|
||||
password: str | None,
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
expected_archives = _get_expected_archives(backup)
|
||||
for obj in input_tar:
|
||||
# We compare with PurePath to avoid issues with different path separators,
|
||||
# for example when backup.json is added as "./backup.json"
|
||||
if PurePath(obj.name) == PurePath("backup.json"):
|
||||
object_path = PurePath(obj.name)
|
||||
if object_path == PurePath("backup.json"):
|
||||
# Rewrite the backup.json file to indicate that the backup is decrypted
|
||||
if not (reader := input_tar.extractfile(obj)):
|
||||
raise DecryptError
|
||||
@@ -352,7 +368,13 @@ def _decrypt_backup(
|
||||
metadata_obj.size = len(updated_metadata_b)
|
||||
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
|
||||
continue
|
||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||
prefix, _, suffix = object_path.name.partition(".")
|
||||
if suffix not in ("tar", "tgz", "tar.gz"):
|
||||
LOGGER.debug("Unknown file %s will not be decrypted", obj.name)
|
||||
output_tar.addfile(obj, input_tar.extractfile(obj))
|
||||
continue
|
||||
if prefix not in expected_archives:
|
||||
LOGGER.debug("Unknown inner tar file %s will not be decrypted", obj.name)
|
||||
output_tar.addfile(obj, input_tar.extractfile(obj))
|
||||
continue
|
||||
istf = SecureTarFile(
|
||||
@@ -371,12 +393,13 @@ def _decrypt_backup(
|
||||
|
||||
|
||||
def encrypt_backup(
|
||||
backup: AgentBackup,
|
||||
input_stream: IO[bytes],
|
||||
output_stream: IO[bytes],
|
||||
password: str | None,
|
||||
on_done: Callable[[Exception | None], None],
|
||||
minimum_size: int,
|
||||
nonces: list[bytes],
|
||||
nonces: NonceGenerator,
|
||||
) -> None:
|
||||
"""Encrypt a backup."""
|
||||
error: Exception | None = None
|
||||
@@ -390,7 +413,7 @@ def encrypt_backup(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_encrypt_backup(input_tar, output_tar, password, nonces)
|
||||
_encrypt_backup(backup, input_tar, output_tar, password, nonces)
|
||||
except (EncryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error encrypting backup: %s", err)
|
||||
error = err
|
||||
@@ -408,17 +431,20 @@ def encrypt_backup(
|
||||
|
||||
|
||||
def _encrypt_backup(
|
||||
backup: AgentBackup,
|
||||
input_tar: tarfile.TarFile,
|
||||
output_tar: tarfile.TarFile,
|
||||
password: str | None,
|
||||
nonces: list[bytes],
|
||||
nonces: NonceGenerator,
|
||||
) -> None:
|
||||
"""Encrypt a backup."""
|
||||
inner_tar_idx = 0
|
||||
expected_archives = _get_expected_archives(backup)
|
||||
for obj in input_tar:
|
||||
# We compare with PurePath to avoid issues with different path separators,
|
||||
# for example when backup.json is added as "./backup.json"
|
||||
if PurePath(obj.name) == PurePath("backup.json"):
|
||||
object_path = PurePath(obj.name)
|
||||
if object_path == PurePath("backup.json"):
|
||||
# Rewrite the backup.json file to indicate that the backup is encrypted
|
||||
if not (reader := input_tar.extractfile(obj)):
|
||||
raise EncryptError
|
||||
@@ -429,16 +455,21 @@ def _encrypt_backup(
|
||||
metadata_obj.size = len(updated_metadata_b)
|
||||
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
|
||||
continue
|
||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||
prefix, _, suffix = object_path.name.partition(".")
|
||||
if suffix not in ("tar", "tgz", "tar.gz"):
|
||||
LOGGER.debug("Unknown file %s will not be encrypted", obj.name)
|
||||
output_tar.addfile(obj, input_tar.extractfile(obj))
|
||||
continue
|
||||
if prefix not in expected_archives:
|
||||
LOGGER.debug("Unknown inner tar file %s will not be encrypted", obj.name)
|
||||
continue
|
||||
istf = SecureTarFile(
|
||||
None, # Not used
|
||||
gzip=False,
|
||||
key=password_to_key(password) if password is not None else None,
|
||||
mode="r",
|
||||
fileobj=input_tar.extractfile(obj),
|
||||
nonce=nonces[inner_tar_idx],
|
||||
nonce=nonces.get(inner_tar_idx),
|
||||
)
|
||||
inner_tar_idx += 1
|
||||
with istf.encrypt(obj) as encrypted:
|
||||
@@ -456,17 +487,33 @@ class _CipherWorkerStatus:
|
||||
writer: AsyncIteratorWriter
|
||||
|
||||
|
||||
class NonceGenerator:
|
||||
"""Generate nonces for encryption."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the generator."""
|
||||
self._nonces: dict[int, bytes] = {}
|
||||
|
||||
def get(self, index: int) -> bytes:
|
||||
"""Get a nonce for the given index."""
|
||||
if index not in self._nonces:
|
||||
# Generate a new nonce for the given index
|
||||
self._nonces[index] = os.urandom(16)
|
||||
return self._nonces[index]
|
||||
|
||||
|
||||
class _CipherBackupStreamer:
|
||||
"""Encrypt or decrypt a backup."""
|
||||
|
||||
_cipher_func: Callable[
|
||||
[
|
||||
AgentBackup,
|
||||
IO[bytes],
|
||||
IO[bytes],
|
||||
str | None,
|
||||
Callable[[Exception | None], None],
|
||||
int,
|
||||
list[bytes],
|
||||
NonceGenerator,
|
||||
],
|
||||
None,
|
||||
]
|
||||
@@ -484,7 +531,7 @@ class _CipherBackupStreamer:
|
||||
self._hass = hass
|
||||
self._open_stream = open_stream
|
||||
self._password = password
|
||||
self._nonces: list[bytes] = []
|
||||
self._nonces = NonceGenerator()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the maximum size of the decrypted or encrypted backup."""
|
||||
@@ -508,7 +555,15 @@ class _CipherBackupStreamer:
|
||||
writer = AsyncIteratorWriter(self._hass)
|
||||
worker = threading.Thread(
|
||||
target=self._cipher_func,
|
||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||
args=[
|
||||
self._backup,
|
||||
reader,
|
||||
writer,
|
||||
self._password,
|
||||
on_done,
|
||||
self.size(),
|
||||
self._nonces,
|
||||
],
|
||||
)
|
||||
worker_status = _CipherWorkerStatus(
|
||||
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
|
||||
@@ -538,17 +593,6 @@ class DecryptedBackupStreamer(_CipherBackupStreamer):
|
||||
class EncryptedBackupStreamer(_CipherBackupStreamer):
|
||||
"""Encrypt a backup."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
backup: AgentBackup,
|
||||
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
|
||||
password: str | None,
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__(hass, backup, open_stream, password)
|
||||
self._nonces = [os.urandom(16) for _ in range(self._num_tar_files())]
|
||||
|
||||
_cipher_func = staticmethod(encrypt_backup)
|
||||
|
||||
def backup(self) -> AgentBackup:
|
||||
|
||||
Reference in New Issue
Block a user