"""Local backup support for Core and Container installations."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy
from dataclasses import dataclass, replace
from io import BytesIO
import json
import os
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast

import aiohttp
from securetar import SecureTarError, SecureTarFile, SecureTarReadError

from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType, json_loads_object

from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder


class DecryptError(HomeAssistantError):
    """Raised when a backup cannot be decrypted; base of the decryption errors."""

    # Class-level message text for this error.
    _message = "Unexpected error during decryption."


class EncryptError(HomeAssistantError):
    """Raised when a backup cannot be encrypted."""

    # Class-level message text for this error.
    _message = "Unexpected error during encryption."


class UnsupportedSecureTarVersion(DecryptError):
    """Raised when an inner tar's securetar header carries no plaintext size."""

    # Class-level message text for this error.
    _message = "Unsupported securetar version."


class IncorrectPassword(DecryptError):
    """Raised when an inner tar fails to decrypt: wrong password or corrupt data."""

    # Class-level message text for this error.
    _message = "Invalid password or corrupted backup."


class BackupEmpty(DecryptError):
    """Raised when a backup contains no inner tar files to check."""

    # Class-level message text for this error.
    _message = "No tar files found in the backup."


class AbortCipher(HomeAssistantError):
    """Raised inside a cipher worker once its reader or writer has been aborted."""

    # Class-level message text for this error.
    _message = "Abort cipher operation."


def make_backup_dir(path: Path) -> None:
    """Ensure the backup directory at *path* exists.

    Only the final directory is created (no parents); an existing directory
    is left untouched.
    """
    path.mkdir(parents=False, exist_ok=True)


def read_backup(backup_path: Path) -> AgentBackup:
    """Build an AgentBackup from the ``./backup.json`` inside a backup tar on disk.

    Raises:
        KeyError: ``./backup.json`` is missing or not extractable, or one of the
            required keys (``date``, ``slug``, ``name``, addon fields) is absent.
    """
    with tarfile.open(backup_path, "r:", bufsize=BUF_SIZE) as archive:
        metadata_file = archive.extractfile("./backup.json")
        if not metadata_file:
            raise KeyError("backup.json not found in tar file")
        data = json_loads_object(metadata_file.read())

        raw_addons = cast(list[JsonObjectType], data.get("addons", []))
        addons = [
            AddonInfo(
                name=cast(str, entry["name"]),
                slug=cast(str, entry["slug"]),
                version=cast(str, entry["version"]),
            )
            for entry in raw_addons
        ]

        # The "homeassistant" entry is reported via homeassistant_included,
        # not as a folder.
        raw_folders = cast(list[str], data.get("folders", []))
        folders = [Folder(name) for name in raw_folders if name != "homeassistant"]

        ha_section = cast(JsonObjectType, data.get("homeassistant"))
        homeassistant_version: str | None
        if ha_section and "version" in ha_section:
            homeassistant_included = True
            homeassistant_version = cast(str, ha_section["version"])
            exclude_database = cast(bool, ha_section.get("exclude_database", False))
            database_included = not exclude_database
        else:
            homeassistant_included = False
            homeassistant_version = None
            database_included = False

        extra_metadata = cast(dict[str, bool | str], data.get("extra", {}))
        # data["date"] is looked up unconditionally, so a missing "date" key
        # raises even when the supervisor request date is present.
        fallback_date = data["date"]
        date = extra_metadata.get("supervisor.backup_request_date", fallback_date)

        return AgentBackup(
            addons=addons,
            backup_id=cast(str, data["slug"]),
            database_included=database_included,
            date=cast(str, date),
            extra_metadata=extra_metadata,
            folders=folders,
            homeassistant_included=homeassistant_included,
            homeassistant_version=homeassistant_version,
            name=cast(str, data["name"]),
            protected=cast(bool, data.get("protected", False)),
            size=backup_path.stat().st_size,
        )


def suggested_filename_from_name_date(name: str, date_str: str) -> str:
    """Build a backup filename from a backup name and a date string.

    The date is parsed with dt_util (raising on invalid input). Every run of
    whitespace in the resulting name is collapsed into a single underscore.
    """
    timestamp = dt_util.parse_datetime(date_str, raise_on_error=True)
    stamp = timestamp.strftime("%Y-%m-%d %H.%M %S%f")
    raw_name = f"{name} {stamp}.tar"
    return "_".join(raw_name.split())


def suggested_filename(backup: AgentBackup) -> str:
    """Return the suggested filename for *backup*, derived from its name and date."""
    name, date = backup.name, backup.date
    return suggested_filename_from_name_date(name, date)


def validate_password(path: Path, password: str | None) -> bool:
    """Return whether *password* opens the Home Assistant tar in the backup at *path*.

    ``homeassistant.tar`` is tried first, then ``homeassistant.tar.gz``.
    Returns False when neither exists, when opening it raises
    ``tarfile.ReadError`` (taken as a wrong password), or on any other error
    (which is logged).
    """
    with tarfile.open(path, "r:", bufsize=BUF_SIZE) as backup_file:
        for candidate, compressed in (
            ("homeassistant.tar", False),
            ("homeassistant.tar.gz", True),
        ):
            try:
                ha_tar = backup_file.extractfile(candidate)
            except KeyError:
                continue
            break
        else:
            LOGGER.error("No homeassistant.tar or homeassistant.tar.gz found")
            return False

        try:
            with SecureTarFile(
                path,  # Not used
                gzip=compressed,
                key=password_to_key(password) if password is not None else None,
                mode="r",
                fileobj=ha_tar,
            ):
                # Opening the inner tar successfully means the password is correct
                return True
        except tarfile.ReadError:
            LOGGER.debug("Invalid password")
            return False
        except Exception:  # noqa: BLE001
            LOGGER.exception("Unexpected error validating password")
    return False


class AsyncIteratorReader:
    """Wrap an AsyncIterator."""

    def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
        """Initialize the wrapper."""
        self._aborted = False
        self._hass = hass
        self._stream = stream
        self._buffer: bytes | None = None
        self._next_future: Future[bytes | None] | None = None
        self._pos: int = 0

    async def _next(self) -> bytes | None:
        """Get the next chunk from the iterator."""
        return await anext(self._stream, None)

    def abort(self) -> None:
        """Abort the reader."""
        self._aborted = True
        if self._next_future is not None:
            self._next_future.cancel()

    def read(self, n: int = -1, /) -> bytes:
        """Read data from the iterator."""
        result = bytearray()
        while n < 0 or len(result) < n:
            if not self._buffer:
                self._next_future = asyncio.run_coroutine_threadsafe(
                    self._next(), self._hass.loop
                )
                if self._aborted:
                    self._next_future.cancel()
                    raise AbortCipher
                try:
                    self._buffer = self._next_future.result()
                except CancelledError as err:
                    raise AbortCipher from err
                self._pos = 0
            if not self._buffer:
                # The stream is exhausted
                break
            chunk = self._buffer[self._pos : self._pos + n]
            result.extend(chunk)
            n -= len(chunk)
            self._pos += len(chunk)
            if self._pos == len(self._buffer):
                self._buffer = None
        return bytes(result)

    def close(self) -> None:
        """Close the iterator."""


class AsyncIteratorWriter:
    """Expose a blocking, file-like ``write`` as an AsyncIterator[bytes].

    A worker thread calls write(). Each chunk goes through a single-slot
    asyncio.Queue, so write() blocks while an earlier chunk has not yet been
    taken by the async consumer. Writing an empty chunk marks the end of the
    stream.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the wrapper."""
        self._aborted = False
        # Set once the end-of-stream marker has been consumed.
        self._finished = False
        self._hass = hass
        self._pos: int = 0
        self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
        # asyncio.Queue.put() resolves to None, hence Future[None].
        self._write_future: Future[None] | None = None

    def __aiter__(self) -> Self:
        """Return the iterator."""
        return self

    async def __anext__(self) -> bytes:
        """Return the next chunk written by the worker thread.

        Raises StopAsyncIteration when an empty chunk is received, and on
        every call after that, instead of waiting on the queue forever.
        """
        if self._finished:
            raise StopAsyncIteration
        if data := await self._queue.get():
            return data
        self._finished = True
        raise StopAsyncIteration

    def abort(self) -> None:
        """Abort the writer, cancelling any pending write."""
        self._aborted = True
        if self._write_future is not None:
            self._write_future.cancel()

    def tell(self) -> int:
        """Return the number of bytes written so far."""
        return self._pos

    def write(self, s: bytes, /) -> int:
        """Hand *s* to the async consumer and return its length.

        Blocks until the event loop has queued the chunk, so it must be called
        from a thread other than the event loop's.

        Raises:
            AbortCipher: the writer was aborted.
        """
        self._write_future = asyncio.run_coroutine_threadsafe(
            self._queue.put(s), self._hass.loop
        )
        # abort() may have run before self._write_future was assigned.
        if self._aborted:
            self._write_future.cancel()
            raise AbortCipher
        try:
            self._write_future.result()
        except CancelledError as err:
            raise AbortCipher from err
        self._pos += len(s)
        return len(s)


def validate_password_stream(
    input_stream: IO[bytes],
    password: str | None,
) -> None:
    """Validate a backup password against the first inner tar of a backup stream.

    The outer tar in *input_stream* is scanned for the first member whose name
    ends in ``.tar``, ``.tgz`` or ``.tar.gz``. A single byte of that member is
    decrypted with a key derived from *password*. Only that first inner tar is
    checked; the function returns as soon as it has been read.

    Raises:
        UnsupportedSecureTarVersion: the inner tar's securetar header has no
            plaintext size.
        IncorrectPassword: decrypting the first byte raised SecureTarReadError.
        BackupEmpty: the stream contains no inner tar files.
    """
    with tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar:
        for obj in input_tar:
            if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
                continue
            istf = SecureTarFile(
                None,  # Not used
                gzip=False,
                key=password_to_key(password) if password is not None else None,
                mode="r",
                fileobj=input_tar.extractfile(obj),
            )
            with istf.decrypt(obj) as decrypted:
                if istf.securetar_header.plaintext_size is None:
                    raise UnsupportedSecureTarVersion
                try:
                    decrypted.read(1)  # Read a single byte to trigger the decryption
                except SecureTarReadError as err:
                    raise IncorrectPassword from err
                return
    raise BackupEmpty


def decrypt_backup(
    input_stream: IO[bytes],
    output_stream: IO[bytes],
    password: str | None,
    on_done: Callable[[Exception | None], None],
    minimum_size: int,
    nonces: list[bytes],
) -> None:
    """Decrypt the backup tar read from *input_stream* into *output_stream*.

    Args:
        input_stream: readable stream containing the encrypted backup tar.
        output_stream: writable stream receiving the decrypted backup tar. An
            empty write signals end of stream.
        password: password used to derive the decryption key, or None.
        on_done: always called exactly once at the end, with the error that
            stopped decryption or None on success (or on abort).
        minimum_size: on success, the output is zero-padded to this size.
        nonces: unused here; accepted so the signature matches encrypt_backup.
    """
    error: Exception | None = None
    try:
        try:
            with (
                tarfile.open(
                    fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
                ) as input_tar,
                tarfile.open(
                    fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
                ) as output_tar,
            ):
                _decrypt_backup(input_tar, output_tar, password)
        except AbortCipher:
            # An abort is not an error; let the outer handler deal with it.
            raise
        except (DecryptError, SecureTarError, tarfile.TarError) as err:
            LOGGER.warning("Error decrypting backup: %s", err)
            error = err
        except Exception as err:  # noqa: BLE001
            # Report unexpected failures through on_done instead of letting
            # the thread die while on_done signals success.
            LOGGER.exception("Unexpected error decrypting backup")
            error = err
        else:
            # Pad the output stream to the requested minimum size. Skip the
            # write when there is nothing to pad: an empty write is the
            # end-of-stream marker.
            padding = max(minimum_size - output_stream.tell(), 0)
            if padding:
                output_stream.write(b"\0" * padding)
        finally:
            # Write an empty chunk to signal the end of the stream
            output_stream.write(b"")
    except AbortCipher:
        LOGGER.debug("Cipher operation aborted")
    finally:
        on_done(error)


def _decrypt_backup(
    input_tar: tarfile.TarFile,
    output_tar: tarfile.TarFile,
    password: str | None,
) -> None:
    """Copy *input_tar* into *output_tar*, decrypting every inner tar member.

    - ``backup.json`` is rewritten with ``"protected": False``.
    - Members ending in ``.tar``/``.tgz``/``.tar.gz`` are decrypted with
      SecureTarFile and written with their plaintext size.
    - Everything else is copied unchanged.

    Raises:
        DecryptError: ``backup.json`` could not be extracted.
        UnsupportedSecureTarVersion: an inner tar's header has no plaintext size.
    """
    for member in input_tar:
        # PurePath comparison makes "./backup.json" and "backup.json" match.
        if PurePath(member.name) == PurePath("backup.json"):
            source = input_tar.extractfile(member)
            if not source:
                raise DecryptError
            metadata = json_loads_object(source.read())
            metadata["protected"] = False
            payload = json.dumps(metadata).encode()
            rewritten = copy.deepcopy(member)
            rewritten.size = len(payload)
            output_tar.addfile(rewritten, BytesIO(payload))
        elif member.name.endswith((".tar", ".tgz", ".tar.gz")):
            secure_tar = SecureTarFile(
                None,  # Not used
                gzip=False,
                key=password_to_key(password) if password is not None else None,
                mode="r",
                fileobj=input_tar.extractfile(member),
            )
            with secure_tar.decrypt(member) as plaintext:
                plaintext_size = secure_tar.securetar_header.plaintext_size
                if plaintext_size is None:
                    raise UnsupportedSecureTarVersion
                decrypted_member = copy.deepcopy(member)
                decrypted_member.size = plaintext_size
                output_tar.addfile(decrypted_member, plaintext)
        else:
            output_tar.addfile(member, input_tar.extractfile(member))


def encrypt_backup(
    input_stream: IO[bytes],
    output_stream: IO[bytes],
    password: str | None,
    on_done: Callable[[Exception | None], None],
    minimum_size: int,
    nonces: list[bytes],
) -> None:
    """Encrypt the backup tar read from *input_stream* into *output_stream*.

    Args:
        input_stream: readable stream containing the unencrypted backup tar.
        output_stream: writable stream receiving the encrypted backup tar. An
            empty write signals end of stream.
        password: password used to derive the encryption key, or None.
        on_done: always called exactly once at the end, with the error that
            stopped encryption or None on success (or on abort).
        minimum_size: on success, the output is zero-padded to this size.
        nonces: one nonce per inner tar, consumed in order.
    """
    error: Exception | None = None
    try:
        try:
            with (
                tarfile.open(
                    fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
                ) as input_tar,
                tarfile.open(
                    fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
                ) as output_tar,
            ):
                _encrypt_backup(input_tar, output_tar, password, nonces)
        except AbortCipher:
            # An abort is not an error; let the outer handler deal with it.
            raise
        except (EncryptError, SecureTarError, tarfile.TarError) as err:
            LOGGER.warning("Error encrypting backup: %s", err)
            error = err
        except Exception as err:  # noqa: BLE001
            # Report unexpected failures (e.g. running out of nonces) through
            # on_done instead of letting the thread die while on_done
            # signals success.
            LOGGER.exception("Unexpected error encrypting backup")
            error = err
        else:
            # Pad the output stream to the requested minimum size. Skip the
            # write when there is nothing to pad: an empty write is the
            # end-of-stream marker.
            padding = max(minimum_size - output_stream.tell(), 0)
            if padding:
                output_stream.write(b"\0" * padding)
        finally:
            # Write an empty chunk to signal the end of the stream
            output_stream.write(b"")
    except AbortCipher:
        LOGGER.debug("Cipher operation aborted")
    finally:
        on_done(error)


def _encrypt_backup(
    input_tar: tarfile.TarFile,
    output_tar: tarfile.TarFile,
    password: str | None,
    nonces: list[bytes],
) -> None:
    """Copy *input_tar* into *output_tar*, encrypting every inner tar member.

    - ``backup.json`` is rewritten with ``"protected": True``.
    - Members ending in ``.tar``/``.tgz``/``.tar.gz`` are encrypted with
      SecureTarFile, consuming *nonces* in order (one per inner tar).
    - Everything else is copied unchanged.

    Raises:
        EncryptError: ``backup.json`` could not be extracted.
        IndexError: there are more inner tars than nonces.
    """
    nonce_index = 0
    for member in input_tar:
        # PurePath comparison makes "./backup.json" and "backup.json" match.
        if PurePath(member.name) == PurePath("backup.json"):
            source = input_tar.extractfile(member)
            if not source:
                raise EncryptError
            metadata = json_loads_object(source.read())
            metadata["protected"] = True
            payload = json.dumps(metadata).encode()
            rewritten = copy.deepcopy(member)
            rewritten.size = len(payload)
            output_tar.addfile(rewritten, BytesIO(payload))
        elif member.name.endswith((".tar", ".tgz", ".tar.gz")):
            secure_tar = SecureTarFile(
                None,  # Not used
                gzip=False,
                key=password_to_key(password) if password is not None else None,
                mode="r",
                fileobj=input_tar.extractfile(member),
                nonce=nonces[nonce_index],
            )
            nonce_index += 1
            with secure_tar.encrypt(member) as ciphertext:
                encrypted_member = copy.deepcopy(member)
                encrypted_member.size = ciphertext.encrypted_size
                output_tar.addfile(encrypted_member, ciphertext)
        else:
            output_tar.addfile(member, input_tar.extractfile(member))


@dataclass(kw_only=True)
class _CipherWorkerStatus:
    done: asyncio.Event
    error: Exception | None = None
    reader: AsyncIteratorReader
    thread: threading.Thread
    writer: AsyncIteratorWriter


class _CipherBackupStreamer:
    """Encrypt or decrypt a backup."""

    _cipher_func: Callable[
        [
            IO[bytes],
            IO[bytes],
            str | None,
            Callable[[Exception | None], None],
            int,
            list[bytes],
        ],
        None,
    ]

    def __init__(
        self,
        hass: HomeAssistant,
        backup: AgentBackup,
        open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
        password: str | None,
    ) -> None:
        """Initialize."""
        self._workers: list[_CipherWorkerStatus] = []
        self._backup = backup
        self._hass = hass
        self._open_stream = open_stream
        self._password = password
        self._nonces: list[bytes] = []

    def size(self) -> int:
        """Return the maximum size of the decrypted or encrypted backup."""
        return self._backup.size + self._num_tar_files() * tarfile.RECORDSIZE

    def _num_tar_files(self) -> int:
        """Return the number of inner tar files."""
        b = self._backup
        return len(b.addons) + len(b.folders) + b.homeassistant_included + 1

    async def open_stream(self) -> AsyncIterator[bytes]:
        """Open a stream."""

        def on_done(error: Exception | None) -> None:
            """Call by the worker thread when it's done."""
            worker_status.error = error
            self._hass.loop.call_soon_threadsafe(worker_status.done.set)

        stream = await self._open_stream()
        reader = AsyncIteratorReader(self._hass, stream)
        writer = AsyncIteratorWriter(self._hass)
        worker = threading.Thread(
            target=self._cipher_func,
            args=[reader, writer, self._password, on_done, self.size(), self._nonces],
        )
        worker_status = _CipherWorkerStatus(
            done=asyncio.Event(), reader=reader, thread=worker, writer=writer
        )
        self._workers.append(worker_status)
        worker.start()
        return writer

    async def wait(self) -> None:
        """Wait for the worker threads to finish."""
        for worker in self._workers:
            worker.reader.abort()
            worker.writer.abort()
        await asyncio.gather(*(worker.done.wait() for worker in self._workers))


class DecryptedBackupStreamer(_CipherBackupStreamer):
    """Stream a decrypted copy of an encrypted backup."""

    _cipher_func = staticmethod(decrypt_backup)

    def backup(self) -> AgentBackup:
        """Return metadata for the decrypted backup.

        The size is the upper bound from size(), not the exact decrypted size.
        """
        max_size = self.size()
        return replace(self._backup, protected=False, size=max_size)


class EncryptedBackupStreamer(_CipherBackupStreamer):
    """Stream an encrypted copy of an unencrypted backup."""

    _cipher_func = staticmethod(encrypt_backup)

    def __init__(
        self,
        hass: HomeAssistant,
        backup: AgentBackup,
        open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
        password: str | None,
    ) -> None:
        """Initialize and generate one random 16-byte nonce per counted inner tar.

        The nonces are created once here, so every stream opened from this
        streamer reuses the same list.
        """
        super().__init__(hass, backup, open_stream, password)
        nonce_count = self._num_tar_files()
        self._nonces = [os.urandom(16) for _ in range(nonce_count)]

    def backup(self) -> AgentBackup:
        """Return metadata for the encrypted backup.

        The size is the upper bound from size(), not the exact encrypted size.
        """
        max_size = self.size()
        return replace(self._backup, protected=True, size=max_size)


async def receive_file(
    hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
) -> None:
    """Receive a file from a stream and write it to *path*.

    Chunks are read on the event loop and passed through a queue to an
    executor job that writes them to disk. Every fifth chunk carries a future
    that the executor job resolves when it dequeues that chunk. Awaiting it
    limits how far reading can run ahead of writing.

    Raises:
        Any exception raised by ``contents.read_chunk`` or by the executor job.
    """
    queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = SimpleQueue()

    def _sync_queue_consumer() -> None:
        with path.open("wb") as file_handle:
            while True:
                if (_chunk_future := queue.get()) is None:
                    break
                _chunk, _future = _chunk_future
                if _future is not None:
                    hass.loop.call_soon_threadsafe(_future.set_result, None)
                file_handle.write(_chunk)

    fut: asyncio.Future[None] | None = None
    try:
        fut = hass.async_add_executor_job(_sync_queue_consumer)
        chunks_sent = 0
        while chunk := await contents.read_chunk(BUF_SIZE):
            chunks_sent += 1
            if chunks_sent % 5 != 0:
                queue.put_nowait((chunk, None))
                continue

            chunk_future = hass.loop.create_future()
            queue.put_nowait((chunk, chunk_future))
            await asyncio.wait(
                (fut, chunk_future),
                return_when=asyncio.FIRST_COMPLETED,
            )
            if fut.done():
                # The executor job failed
                break
    finally:
        # Always terminate the queue consumer, including when reading the
        # upload failed or the task was cancelled. Otherwise the executor
        # thread blocks forever on queue.get() and the await below never
        # returns.
        queue.put_nowait(None)
        if fut is not None:
            await fut