import functools
import sys
from importlib import import_module
from typing import List, Tuple

from packaging.version import Version


def get_wrong_dependencies_versions(
    dependencies_versions: List[Tuple[str, str, str]],
) -> List[Tuple[str, str, str, str]]:
    """
    Get a list of dependencies whose installed version does not satisfy a constraint.

    E.g., calling `get_wrong_dependencies_versions([("torch", "==", "1.2.0")])`
    checks whether the installed version of `torch` is `==1.2.0`. If not, it
    returns `[("torch", "==", "1.2.0", "<current_installed_version>")]`.

    Supported comparison operators: `<=`, `==`, `>=`. Versions are compared
    with `packaging.version.Version`.

    Args:
        dependencies_versions (List[Tuple[str, str, str]]): Constraints to check,
            `[("<package_name>", "<operator>", "<version_number_to_check>")]`.
            `<package_name>` is passed to `import_module` and the resulting
            module must expose `__version__`.

    Returns:
        List[Tuple[str, str, str, str]]: Constraints that are not satisfied,
            `[("<package_name>", "<operator>", "<version_number_to_check>", "<current_version>")]`.

    Raises:
        ValueError: If any operator is not supported. All operators are
            validated before any package is imported.
        ImportError: If a package cannot be imported.
        AttributeError: If an imported module has no `__version__`.
    """
    order_funcs = {
        "==": lambda x, y: x == y,
        ">=": lambda x, y: x >= y,
        "<=": lambda x, y: x <= y,
    }
    # Materialize so a one-shot iterable can be traversed twice below.
    dependencies_versions = list(dependencies_versions)

    # Validate operators up front: otherwise a typo'd operator on a package that
    # is not installed would surface as an ImportError instead of this ValueError,
    # and earlier entries would already have been imported.
    for _, order, _ in dependencies_versions:
        if order not in order_funcs:
            raise ValueError(f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`")

    wrong_dependencies_versions = []
    for dependency, order, version in dependencies_versions:
        module = import_module(dependency)
        module_version = module.__version__
        is_okay = order_funcs[order](Version(module_version), Version(version))
        if not is_okay:
            wrong_dependencies_versions.append((dependency, order, version, module_version))
    return wrong_dependencies_versions


def print_warn_for_wrong_dependencies_versions(
    dependencies_versions: List[Tuple[str, str, str]], ask_to_continue: bool = False
):
    """
    Print a warning for every dependency whose installed version does not match.

    The check is delegated to `get_wrong_dependencies_versions`; each mismatch
    is printed together with the `pip install` command that would fix it.

    Args:
        dependencies_versions (List[Tuple[str, str, str]]): Constraints to check,
            `[("<package_name>", "<operator>", "<version_number_to_check>")]`.
        ask_to_continue (bool): If True, after each warning prompt the user on
            stdin and exit the process with status 1 unless the answer is "y"
            (case-insensitive, surrounding whitespace ignored).
    """
    wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions)
    for dependency, order, version, module_version in wrong_dependencies_versions:
        print(
            f"Dependency {dependency}{order}{version} is required but found"
            f" version={module_version}, to fix: `pip install"
            f" {dependency}{order}{version}`"
        )
        if ask_to_continue:
            answer = input(f"Would you like to continue with the wrong version of {dependency}? y/n: ")
            # Strip so stray whitespace (e.g. "y ") is not treated as a refusal.
            if answer.strip().lower() != "y":
                sys.exit(1)


def warn_for_wrong_dependencies_versions(dependencies_versions: List[Tuple[str, str, str]]):
    """
    Decorator to print a warning based on dependencies versions. E.g.

    ```python
    @warn_for_wrong_dependencies_versions([("torch", "==", "1.2.0")])
    def foo(x):
        # I only work with torch `1.2.0` but another one is installed
        print(f"foo {x}")
    ```

    prints:

    ```
    Dependency torch==1.2.0 is required but found version=1.13.1,
        to fix: `pip install torch==1.2.0`
    ```

    The check runs on every call of the decorated function, before the function
    itself. The decorated function's return value is passed through and its
    metadata (`__name__`, `__doc__`, ...) is preserved.

    Args:
        dependencies_versions (List[Tuple[str, str, str]]): Constraints to check,
            `[("<package_name>", "<operator>", "<version_number_to_check>")]`.
    """

    def _inner(func):
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            print_warn_for_wrong_dependencies_versions(dependencies_versions)
            # Return the result; previously it was silently discarded.
            return func(*args, **kwargs)

        return _wrapper

    return _inner


def normalize_yolo_model_type(model_type: str) -> str:
    """
    Rewrite short YOLO names to their `yolov` spelling.

    Every occurrence of `yolo11` becomes `yolov11`, then every occurrence of
    `yolo12` becomes `yolov12`; all other text is returned unchanged.
    """
    for short_name, long_name in (("yolo11", "yolov11"), ("yolo12", "yolov12")):
        model_type = model_type.replace(short_name, long_name)
    return model_type


def get_model_format(model_type: str) -> str:
    """
    Get the model format for a given model type.

    The format is chosen by the first known key (checked in the order listed
    below) that occurs anywhere in `model_type`. Unknown model types fall back
    to `"yolov5pytorch"`.

    Args:
        model_type (str): The model type to get the format for.

    Returns:
        str: The model format.

    Example:
        >>> get_model_format("yolov5v6n")
        'yolov5pytorch'
        >>> get_model_format("rfdetr-nano")
        'coco'
        >>> get_model_format("yolov11n")
        'yolov5pytorch'
    """
    # Keys extracted from modelRegistry.js in roboflow.
    # NOTE: they are matched as substrings (not `startswith`), in insertion
    # order, so the first key found anywhere in `model_type` wins.
    model_formats = {
        "yolo": "yolov5pytorch",
        "pali": "jsonl",
        "flor": "jsonl",
        "qwen": "jsonl",
        "smol": "jsonl",
        "vit-b": "folder",
        "resn": "folder",
        "rfdetr": "coco",
        "rf-detr": "coco",
        "deep": "png-mask-semantic",
    }

    return next(
        (model_format for key, model_format in model_formats.items() if key in model_type),
        "yolov5pytorch",
    )
