import io
import json
import os
import time
import urllib.parse
from typing import List, Optional, Tuple
from urllib.parse import urljoin

import requests
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder
from tqdm import tqdm

from roboflow.config import API_URL
from roboflow.util.image_utils import validate_image_path
from roboflow.util.prediction import PredictionGroup

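# Prediction types accepted for hosted batch video inference.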
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

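# Additional models that can be run alongside the primary model during video
# inference; each entry is appended to the "models" list of the /videoinfer payload.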
SUPPORTED_ADDITIONAL_MODELS = {
    "clip": {
        "model_id": "clip",
        "model_version": "1",
        "inference_type": "clip-embed-image",
    },
    "gaze": {
        "model_id": "gaze",
        "model_version": "1",
        "inference_type": "gaze-detection",
    },
}


class InferenceModel:
    def __init__(
        self,
        api_key,
        version_id,
        colors=None,
        *args,
        **kwargs,
    ):
        """
        Create an InferenceModel object through which you can run inference.

        Args:
            api_key (str): private Roboflow API key
            version_id (str): the ID of the dataset version to use for inference,
                in the form "workspace/project/version"
            colors (dict): optional mapping of class names to display colors
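
        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> model = rf.workspace().project("PROJECT_ID").version("1").model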
        """

        self.__api_key = api_key
        self.id = version_id

        if version_id != "BASE_MODEL":
            version_info = self.id.rsplit("/")
            self.dataset_id = version_info[1]
            self.version = version_info[2]
            self.colors = {} if colors is None else colors

    def __get_image_params(self, image_path):
        """
        Get parameters about an image (i.e. dimensions) for use in an inference request.

        Args:
            image_path (Union[str, np.ndarray]): path to image or numpy array

        Returns:
            Tuple containing a dict of querystring params and a dict of requests kwargs

        Raises:
            Exception: Image path is not valid
        """
        import numpy as np

        if isinstance(image_path, np.ndarray):
            # Convert numpy array to PIL Image; normalize to RGB since JPEG
            # cannot encode an alpha channel
            image = Image.fromarray(image_path).convert("RGB")
            dimensions = image.size
            image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
            buffered = io.BytesIO()
            image.save(buffered, quality=90, format="JPEG")
            data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")})
            return {}, {"data": data, "headers": {"Content-Type": data.content_type}}, image_dims

        validate_image_path(image_path)

        hosted_image = urllib.parse.urlparse(image_path).scheme in ("http", "https")

        if hosted_image:
            image_dims = {"width": "Undefined", "height": "Undefined"}
            return {"image": image_path}, {}, image_dims

        image = Image.open(image_path).convert("RGB")  # JPEG cannot encode alpha/palette modes
        dimensions = image.size
        image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
        buffered = io.BytesIO()
        image.save(buffered, quality=90, format="JPEG")
        data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")})
        return (
            {},
            {"data": data, "headers": {"Content-Type": data.content_type}},
            image_dims,
        )

    def predict(self, image_path, prediction_type=None, **kwargs):
        """
        Infers detections based on image from a specified model and image path.

        Args:
            image_path (str): path to the image you'd like to perform prediction on
            prediction_type (str): type of prediction to perform
            **kwargs: Any additional kwargs will be turned into querystring params

        Returns:
            PredictionGroup Object

        Raises:
            Exception: Image path is not valid

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> prediction = model.predict("YOUR_IMAGE.jpg")
        """
        params, request_kwargs, image_dims = self.__get_image_params(image_path)

        params["api_key"] = self.__api_key

        params.update(**kwargs)
        url = f"{self.api_url}?{urllib.parse.urlencode(params)}"  # type: ignore[attr-defined]
        response = requests.post(url, **request_kwargs)
        response.raise_for_status()

        return PredictionGroup.create_prediction_group(
            response.json(),
            image_path=image_path,
            prediction_type=prediction_type,
            image_dims=image_dims,
            colors=self.colors,
        )

    def predict_video(
        self,
        video_path: str,
        fps: int = 5,
        additional_models: Optional[List[str]] = None,
        prediction_type: str = "batch-video",
    ) -> Tuple[str, str, Optional[str]]:
        """
        Infers detections based on image from specified model and image path.

        Args:
            video_path (str): path to the video you'd like to perform prediction on
            prediction_type (str): type of the model to run
            fps (int): frames per second to run inference

        Returns:
            A list of the signed url and job id

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> job_id,signed_url,signed_url_expires = model.predict_video("video.mp4"
                ,fps=5, inference_type="object-detection")
        """

        signed_url_expires = None

        url = urljoin(API_URL, "/video_upload_signed_url?api_key=" + self.__api_key)
        if fps > 120:
            raise Exception("FPS must be less than or equal to 120.")

        if additional_models is None:
            additional_models = []

        for model in additional_models:
            if model not in SUPPORTED_ADDITIONAL_MODELS:
                raise Exception(f"Model {model} is not supported for video inference.")

        if prediction_type not in SUPPORTED_ROBOFLOW_MODELS:
            raise Exception(f"{prediction_type} is not supported for video inference.")

        model_class = self.__class__.__name__

        # Map the model class to the inference type expected by the video API.
        inference_types = {
            "ObjectDetectionModel": "object-detection",
            "ClassificationModel": "classification",
            "InstanceSegmentationModel": "instance-segmentation",
            "GazeModel": "gaze-detection",
            "CLIPModel": "clip-embed-image",
            "KeypointDetectionModel": "keypoint-detection",
        }

        if model_class not in inference_types:
            raise Exception("Model type not supported for video inference.")

        self.type = inference_types[model_class]

        payload = json.dumps(
            {
                "file_name": os.path.basename(video_path),
            }
        )

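        # Local videos are first uploaded to a signed storage URL; videos that are
        # already hosted at an http(s) URL are passed through as-is.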
        if not video_path.startswith(("http://", "https://")):
            headers = {"Content-Type": "application/json"}

            try:
                response = requests.post(url, headers=headers, data=payload)
            except Exception as e:
                raise Exception(f"Error uploading video: {e}")

            if not response.ok:
                raise Exception(f"Error uploading video: {response.text}")

            signed_url = response.json()["signed_url"]

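            # The signed URL embeds its expiration as the X-Goog-Expires query
            # parameter (in seconds); pull that value out to return to the caller.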
            signed_url_expires = signed_url.split("&X-Goog-Expires")[1].split("&")[0].strip("=")

            # make a POST request to the signed URL
            headers = {"Content-Type": "application/octet-stream"}

            try:
                with open(video_path, "rb") as f:
                    video_data = f.read()
            except Exception as e:
                raise Exception(f"Error reading video: {e}")

            try:
                result = requests.put(signed_url, data=video_data, headers=headers)
            except Exception as e:
                raise Exception(f"There was an error uploading the video: {e}")

            if not result.ok:
                raise Exception(f"There was an error uploading the video: {result.text}")
        else:
            signed_url = video_path

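        # Start the inference job: POST the video's signed URL, the sampling fps,
        # and the list of models to run to the /videoinfer endpoint.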
        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)
        if model_class in ("CLIPModel", "GazeModel"):
            if model_class == "CLIPModel":
                model = "clip"
            else:
                model = "gaze"

            models = [
                {
                    "model_id": SUPPORTED_ADDITIONAL_MODELS[model]["model_id"],
                    "model_version": SUPPORTED_ADDITIONAL_MODELS[model]["model_version"],
                    "inference_type": SUPPORTED_ADDITIONAL_MODELS[model]["inference_type"],
                }
            ]
        else:
            models = [
                {
                    "model_id": self.dataset_id,
                    "model_version": self.version,
                    "inference_type": self.type,
                }
            ]

        for model in additional_models:
            models.append(SUPPORTED_ADDITIONAL_MODELS[model])

        payload = json.dumps({"input_url": signed_url, "infer_fps": fps, "models": models})

        headers = {"Content-Type": "application/json"}

        try:
            response = requests.post(url, headers=headers, data=payload)
        except Exception as e:
            raise Exception(f"Error starting video inference: {e}")

        if not response.ok:
            raise Exception(f"Error starting video inference: {response.text}")

        job_id = response.json()["job_id"]

        self.job_id = job_id

        return job_id, signed_url, signed_url_expires

    def poll_for_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Polls the Roboflow API to check if video inference is complete.

        Returns:
            Inference results as a dict

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> prediction = model.predict("video.mp4")

            >>> results = model.poll_for_video_results()
        """

        if job_id is None:
            job_id = self.job_id

        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + job_id)
        try:
            response = requests.get(url, headers={"Content-Type": "application/json"})
        except Exception as e:
            raise Exception(f"Error getting video inference results: {e}")

        if not response.ok:
            raise Exception(f"Error getting video inference results: {response.text}")
        data = response.json()
        if "status" not in data:
            return {}  # No status available
        if data.get("status") > 1:
            return data  # Error
        elif data.get("status") == 1:
            return {}  # Still running
        else:  # done
            output_signed_url = data["output_signed_url"]
            inference_data = requests.get(output_signed_url, headers={"Content-Type": "application/json"})

            # frame_offset and model name are top-level keys
            return inference_data.json()

    def poll_until_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Poll the Roboflow API every 60 seconds until video inference is complete.

        When inference is complete, the results are returned.

        Args:
            job_id (str): the job id returned by predict_video; defaults to the last
                job started from this model

        Returns:
            Inference results as a dict

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> job_id, signed_url, expires = model.predict_video("video.mp4")

            >>> results = model.poll_until_video_results(job_id)
        """
        if job_id is None:
            job_id = self.job_id

        attempts = 0
        print(f"Checking for video inference results for job {job_id} every 60s")
        while True:
            time.sleep(60)
            attempts += 1
            print(f"({attempts * 60}s): Checking for inference results")
            response = self.poll_for_video_results(job_id)

            if response != {}:
                return response

    def download(self, format="pt", location="."):
        """
        Download the weights associated with a model.

        Args:
            format (str): The format of the output.
                        - 'pt': returns a PyTorch weights file
            location (str): The location to save the weights file to
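
        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> model = rf.workspace().project("PROJECT_ID").version("1").model

            >>> model.download(format="pt", location=".")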
        """
        supported_formats = ["pt"]
        if format not in supported_formats:
            raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")

        workspace, project, version = self.id.rsplit("/")

        # get pt url
        pt_api_url = f"{API_URL}/{workspace}/{project}/{version}/ptFile"

        r = requests.get(pt_api_url, params={"api_key": self.__api_key})

        r.raise_for_status()

        pt_weights_url = r.json()["weightsUrl"]

        response = requests.get(pt_weights_url, stream=True)

        # write the weights file to the desired location
        with open(os.path.join(location, "weights.pt"), "wb") as f:
            total_length = int(response.headers.get("content-length"))  # type: ignore[arg-type]
            for chunk in tqdm(
                response.iter_content(chunk_size=1024),
                desc=f"Downloading weights to {location}/weights.pt",
                total=int(total_length / 1024) + 1,
            ):
                if chunk:
                    f.write(chunk)
                    f.flush()

        return
