"""Block blocking calls being done in asyncio."""

import builtins
from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
import glob
from http.client import HTTPConnection
import importlib
import os
from pathlib import Path
from ssl import SSLContext
import sys
import threading
import time
from typing import Any

from .helpers.frame import get_current_frame
from .util.loop import protect_loop

# True if the ``unittest`` module had already been imported when this module
# was loaded; enable() then leaves every BlockingCall marked
# ``skip_for_tests=True`` unpatched.
_IN_TESTS: bool = "unittest" in sys.modules

ALLOWED_FILE_PREFIXES = ("/proc",)


def _check_import_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True when the module being imported is already in ``sys.modules``.

    Already-imported modules can be ignored.  The module name is taken from
    the first positional argument or, when there are no positional
    arguments, from the ``name`` keyword (``importlib.import_module(name=...)``).
    Names are looked up verbatim, so relative names such as ``".sub"`` are
    not resolved against a package.
    """
    if args := mapped_args.get("args"):
        name = args[0]
    else:
        # import_module() may be called with the module name as a keyword.
        name = (mapped_args.get("kwargs") or {}).get("name")
    return name is not None and name in sys.modules


def _check_file_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True when the file being accessed lives under an allowed prefix.

    The target is the first positional argument or, when there are no
    positional arguments, the ``file`` keyword used by ``open(file=...)``.
    (For the ``pathlib.Path`` methods the first positional argument is
    presumably the ``Path`` instance itself — confirm against protect_loop.)
    Paths starting with one of ``ALLOWED_FILE_PREFIXES`` (currently
    ``/proc``) are ignored.  Anything else is reported, including a missing
    target or an integer file descriptor.
    """
    if args := mapped_args.get("args"):
        target = args[0]
    else:
        # open(file=...) passes the path as a keyword; indexing args[0]
        # unconditionally would raise IndexError inside the check itself.
        target = (mapped_args.get("kwargs") or {}).get("file")
    if type(target) is str:
        # Fast path for the common plain-string case.
        path = target
    elif isinstance(target, (bytes, os.PathLike)):
        # str(b"/proc/x") would give "b'/proc/x'" and never match a prefix;
        # decode filesystem paths properly instead.
        path = os.fsdecode(target)
    else:
        # str subclasses, integer file descriptors, None, ...
        path = str(target)
    return path.startswith(ALLOWED_FILE_PREFIXES)


def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True when ``get_current_frame(4)`` comes from ``pydevd.py``.

    ``mapped_args`` is not inspected; the decision is based only on the
    filename of a caller frame.  NOTE(review): pydevd.py is presumably the
    debugger backend, whose own sleeps should not be reported — confirm.
    """
    #
    # Avoid extracting the stack unless we need to since it
    # will have to access the linecache which can do blocking
    # I/O and we are trying to avoid blocking calls.
    #
    # frame[0] is us
    # frame[1] is raise_for_blocking_call
    # frame[2] is protected_loop_func
    # frame[3] is the offender
    #
    # NOTE(review): the map above names index 3 as the offender while the
    # call below requests depth 4; whether that selects the offender or its
    # caller depends on how get_current_frame counts depth — confirm.
    #
    # A ValueError (presumably raised when the stack is shallower than the
    # requested depth) is suppressed, and the call is then not allowed.
    with suppress(ValueError):
        return get_current_frame(4).f_code.co_filename.endswith("pydevd.py")
    return False


def _check_load_verify_locations_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Allow the call when ``cadata`` is the one and only keyword argument.

    Positional arguments are not examined; missing/empty keyword arguments,
    or any keyword besides ``cadata``, make the call count as blocking.
    """
    keyword_args = mapped_args.get("kwargs")
    if not keyword_args:
        return False
    return len(keyword_args) == 1 and "cadata" in keyword_args


@dataclass(slots=True, frozen=True)
class BlockingCall:
    """Class to hold information about a blocking call.

    Each instance describes one attribute (``object.function``) that
    enable() replaces with a protect_loop() wrapper.  Being a frozen
    dataclass it is hashable, so instances can be stored in a set.
    """

    # The unwrapped callable handed to protect_loop().
    original_func: Callable
    # Module or class that owns the attribute being replaced.
    object: object
    # Name of the attribute on ``object`` that enable() replaces via setattr().
    function: str
    # Optional predicate over the mapped call arguments; forwarded to
    # protect_loop() as ``check_allowed``.
    check_allowed: Callable[[dict[str, Any]], bool] | None
    # Forwarded to protect_loop() as ``strict``.
    strict: bool
    # Forwarded to protect_loop() as ``strict_core``.
    strict_core: bool
    # When True, enable() leaves this attribute unpatched if _IN_TESTS is set.
    skip_for_tests: bool


# Every call that enable() may wrap with protect_loop().  Each row reads:
# (owner, attribute name, check_allowed, strict, strict_core, skip_for_tests)
# and the original callable is fetched with getattr(owner, attribute name).
_BLOCKING_CALLS: tuple[BlockingCall, ...] = tuple(
    BlockingCall(
        original_func=getattr(owner, attr),
        object=owner,
        function=attr,
        check_allowed=checker,
        strict=strict,
        strict_core=strict_core,
        skip_for_tests=skip,
    )
    for owner, attr, checker, strict, strict_core, skip in (
        # strict / strict_core, patched even when _IN_TESTS is set.
        (HTTPConnection, "putrequest", None, True, True, False),
        (time, "sleep", _check_sleep_call_allowed, True, True, False),
        # Non-strict, patched even when _IN_TESTS is set.
        (glob, "glob", None, False, False, False),
        (glob, "iglob", None, False, False, False),
        (os, "walk", None, False, False, False),
        # Non-strict, skipped when _IN_TESTS is set.
        (os, "listdir", None, False, False, True),
        (os, "scandir", None, False, False, True),
        (builtins, "open", _check_file_allowed, False, False, True),
        (importlib, "import_module", _check_import_call_allowed, False, False, True),
        (SSLContext, "load_default_certs", None, False, False, True),
        (
            SSLContext,
            "load_verify_locations",
            _check_load_verify_locations_call_allowed,
            False,
            False,
            True,
        ),
        (SSLContext, "load_cert_chain", None, False, False, True),
        (SSLContext, "set_default_verify_paths", None, False, False, True),
        (Path, "open", _check_file_allowed, False, False, True),
        (Path, "read_text", _check_file_allowed, False, False, True),
        (Path, "read_bytes", _check_file_allowed, False, False, True),
        (Path, "write_text", _check_file_allowed, False, False, True),
        (Path, "write_bytes", _check_file_allowed, False, False, True),
    )
)


@dataclass(slots=True)
class BlockedCalls:
    """Class to track which calls are blocked.

    ``calls`` holds the BlockingCall entries that enable() has patched.
    enable() treats a non-empty set as "already enabled".
    """

    calls: set[BlockingCall]


# Module-level registry that enable() fills in.
_BLOCKED_CALLS = BlockedCalls(set())


def enable() -> None:
    """Enable the detection of blocking calls in the event loop.

    Each entry of ``_BLOCKING_CALLS`` is replaced on its owning object by a
    protect_loop() wrapper bound to the calling thread's id (expected to be
    the event-loop thread), then recorded in ``_BLOCKED_CALLS``.  Entries
    marked ``skip_for_tests`` are left alone while ``_IN_TESTS`` is set.

    Raises:
        RuntimeError: if detection has already been enabled.
    """
    patched = _BLOCKED_CALLS.calls
    if patched:
        raise RuntimeError("Blocking call detection is already enabled")

    current_thread = threading.get_ident()
    selected = (
        entry
        for entry in _BLOCKING_CALLS
        if not (_IN_TESTS and entry.skip_for_tests)
    )
    for entry in selected:
        wrapper = protect_loop(
            entry.original_func,
            strict=entry.strict,
            strict_core=entry.strict_core,
            check_allowed=entry.check_allowed,
            loop_thread_id=current_thread,
        )
        setattr(entry.object, entry.function, wrapper)
        patched.add(entry)