# predict_from_camera_rtsp_realtime.py

import datetime
import glob  # For finding files to clean up (not yet used)
import http.server
import json
import os
import socketserver
import tempfile
import threading
import time
import uuid  # For generating unique image filenames
from queue import Queue, Empty
from threading import Thread

import cv2
import numpy as np
import paho.mqtt.client as mqtt
from speciesnet import SpeciesNet, DEFAULT_MODEL

# --- Configuration ---

#PC_IP_ADDRESS = "fluffinator.jeffsawyer.us" 
PC_IP_ADDRESS = "192.168.215.105" 
WEB_SERVER_PORT = 6969
IMAGE_DIRECTORY = "images" 

IMAGE_URL_BASE = "https://fluffinator.jeffsawyer.us/"

# --- MQTT Configuration ---
# Replace with your MQTT broker's IP address or hostname
BROKER_ADDRESS = "hainternal.jeffsawyer.us" 
BROKER_PORT = 1883
# This is the topic Home Assistant will listen to.
MQTT_TOPIC = "home/backyard/fluffinator"


# --- MQTT Authentication ---
# Replace with your actual username and password
MQTT_USERNAME = "fluffmqtt"
MQTT_PASSWORD = "fluff1234"


# --- Payload Configuration ---
# These are the messages the script will send.
# Home Assistant will be configured to treat these as "ON" and "OFF".
STATUS_ON = "detected"
STATUS_OFF = "clear"


def start_web_server():
    """Starts a simple HTTP server in a separate thread."""
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    Handler = http.server.SimpleHTTPRequestHandler
    httpd = socketserver.TCPServer(("", WEB_SERVER_PORT), Handler)
    print(f"Serving images at http://{PC_IP_ADDRESS}:{WEB_SERVER_PORT}/{IMAGE_DIRECTORY}/")
    thread = threading.Thread(target=httpd.serve_forever)
    thread.daemon = True
    thread.start()

def generate_temp_filepath(directory):
    """
    Generates a unique file path without creating the file.

    Args:
        directory (str): The directory where the file will be located.

    Returns:
        str: A unique path such as 'images/capture_20250101_120000_a1b2c3d4.jpg'.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    unique_filename = f"capture_{timestamp}_{uuid.uuid4().hex[:8]}.jpg"
    return os.path.join(directory, unique_filename)

def publish_detection(status, motion_type="none", image_filename=None):
    """Publishes detection data, including the image URL, as a JSON payload."""
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, "Fluffinator")
    
    client.username_pw_set(MQTT_USERNAME, MQTT_PASSWORD)

    image_url = "none"
    if image_filename:
        image_url = f"{IMAGE_URL_BASE}/{IMAGE_DIRECTORY}/{image_filename}"

    payload = {
        "status": status,
        "type": motion_type,
        "image_url": image_url
    }
    json_payload = json.dumps(payload)

    try:
        client.connect(BROKER_ADDRESS, BROKER_PORT)
        client.publish(MQTT_TOPIC, json_payload)
        client.disconnect()
        print(f"Successfully published: {json_payload}")
    except Exception as e:
        print(f"An error occurred: {e}")


class ThreadedCamera:
    """
    A class that uses a dedicated thread and a single-item queue to ensure
    the latest frame is always available without any lag.
    """
    def __init__(self, src=0):
        self.capture = cv2.VideoCapture(src)
        if not self.capture.isOpened():
            print(f"Error: Could not open video stream at {src}")
            raise IOError("Cannot open video stream")
            
        # Use a queue of size 1 to store the latest frame
        self.q = Queue(maxsize=1)
        self.stopped = False
        
        # Start the thread to read frames from the video stream
        self.thread = Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()

    def update(self):
        """The main loop of the reading thread."""
        while not self.stopped:
            ret, frame = self.capture.read()
            if not ret:
                # Don't call stop() here: it would try to join() this thread
                # from inside itself. Just flag the loop to end.
                self.stopped = True
                return

            # If the queue is full, it means the main thread hasn't consumed
            # the last frame yet. We remove it to make space for the new one.
            if self.q.full():
                try:
                    self.q.get_nowait() # Discard the old frame
                except Empty:
                    pass
            
            # Put the new, most recent frame onto the queue.
            self.q.put(frame)

    def read(self):
        """Return the latest frame, or None if no frame arrives within a second."""
        try:
            return self.q.get(timeout=1.0)
        except Empty:
            return None

    def stop(self):
        """Signal the thread to stop."""
        self.stopped = True
        self.thread.join()
        self.capture.release()

def predictor_thread(model, frame_queue, result_queue, result_frame_q):
    """A thread that takes frames from a queue and runs prediction."""
    while True:
        frame = frame_queue.get()
        if frame is None: # Signal to exit
            break

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            temp_filename = tmp.name
            cv2.imwrite(temp_filename, frame)
        
        try:
            # --- Add location data to the instance ---
            instances_dict = {
                "instances": [
                    {
                        "filepath": temp_filename,
                        "country": "USA",
                        "admin1_region": "CA"
                    }
                ]
            }
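            # 'country' is an ISO 3166-1 alpha-3 code and 'admin1_region' a
            # state/province code; SpeciesNet's geofencing uses them to rule
            # out species that don't occur in that region.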
            # --------------------------------------------------------

            prediction = model.predict(
                instances_dict=instances_dict,
                batch_size=1,
                progress_bars=False
            )
            result_queue.put(prediction)
            result_frame_q.put(frame)
        finally:
            os.remove(temp_filename)


def main():
    print("Initializing SpeciesNet model...")
    model = SpeciesNet(DEFAULT_MODEL, components="all", geofence=True)
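    # components="all" loads the detector, the classifier, and the ensemble;
    # geofence=True enables location-based filtering of the classifier output.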
    print("Model loaded. Initializing real-time stream...")

    rtsp_url = "rtsp://fluffcam:fluff1234@192.168.215.40/stream1"
    try:
        video_stream = ThreadedCamera(rtsp_url)
    except IOError as e:
        print(e)
        return

    frame_for_prediction_q = Queue(maxsize=1)
    
    prediction_result_q = Queue(maxsize=1)
    prediction_result_frame_q = Queue(maxsize=1)
    pred_thread = Thread(target=predictor_thread, args=(model, frame_for_prediction_q, prediction_result_q, prediction_result_frame_q))
    pred_thread.daemon = True
    pred_thread.start()


    if not os.path.exists(IMAGE_DIRECTORY):
        os.makedirs(IMAGE_DIRECTORY)
        print(f"Created directory: ./{IMAGE_DIRECTORY}")

    # Start the web server
    print("\nStarting webserver.")
    start_web_server()

    print("\nStarting motion based prediction.")
    
    latest_prediction_result = {}

    display_info = None
    display_scale_factor = 0.25
    save_image = 0

    background_frame = None
    background_frame_timestamp = None
    min_contour_area = 1000  # Minimum size of motion to consider (pixels)
    last_motion_time = 0
    prediction_cooldown = 5.0 # Seconds to wait before re-triggering
    frame_to_display = None

    
    publish_detection(status=STATUS_OFF)

    while True:
        # The .read() call now guarantees the MOST RECENT frame
        frame = video_stream.read()
        if frame is None:
            continue

        real_time_frame = frame.copy()

 
        if frame_to_display is None:
            frame_to_display = frame.copy()

            h, w, _ = frame_to_display.shape
            timestamp_text = datetime.datetime.now().strftime("%A %d %B %Y %I:%M:%S%p")
            cv2.rectangle(frame_to_display, (0,0), (w,70), (255,255,255), -1)
            cv2.putText(frame_to_display, "Initial Load: " + timestamp_text, (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5)
            new_width, new_height = int(w * display_scale_factor), int(h * display_scale_factor)
            frame_to_display = cv2.resize(frame_to_display, (new_width, new_height))

            full_image_path = generate_temp_filepath(IMAGE_DIRECTORY)

            cv2.imwrite(full_image_path, frame_to_display)
            filename_for_url = os.path.basename(full_image_path)

            # Publish the startup snapshot so Home Assistant has an image right away.
            publish_detection(
                status=STATUS_OFF, 
                motion_type="Load", 
                image_filename=filename_for_url
            )

            print("Saved image: ", full_image_path)
            
        
        # 1. Prepare frame for motion detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (21, 21), 0)

        # 2. Initialize the background frame on the first run
        #    (it is also re-captured whenever motion resets it further below)

        if background_frame is None:
            background_frame = gray.copy().astype("float")
            background_frame_timestamp = datetime.datetime.now().strftime("%A %d %B %Y %I:%M:%S%p")
            continue

        # 3. Update the running average of the background
        cv2.accumulateWeighted(gray, background_frame, 0.1)
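        # accumulateWeighted keeps background = 0.9*background + 0.1*current,
        # so the reference frame slowly adapts to gradual lighting changes.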
        frame_delta = cv2.absdiff(gray, cv2.convertScaleAbs(background_frame))

        # 4. Threshold the delta image to get regions of significant change
        thresh = cv2.threshold(frame_delta, 30, 255, cv2.THRESH_BINARY)[1]
        thresh = cv2.dilate(thresh, None, iterations=2)
        
        # 5. Find contours (i.e., areas of motion)
        contours, _ = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        motion_detected = False
        for c in contours:
            if cv2.contourArea(c) > min_contour_area:
                motion_detected = True
                break # Found motion, no need to check other contours

        # 6. Trigger prediction if motion is detected and cooldown has passed
        current_time = time.time()
        if motion_detected and (current_time - last_motion_time) > prediction_cooldown:
            last_motion_time = current_time
            #print(f"Motion detected at {time.ctime()}! Sending frame for analysis.")
            background_frame = None
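            # Clearing the background forces it to be re-captured on the next
            # frame, which also refreshes the "Last Motion" timestamp overlay.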
            
            if not frame_for_prediction_q.full():
                frame_for_prediction_q.put(frame.copy())


        # --- Check for new prediction results and parse them for display ---
        if not prediction_result_q.empty():
            latest_prediction_result = prediction_result_q.get()
            frame = prediction_result_frame_q.get()
            #print("--- New Prediction Result ---")
            #print(json.dumps(latest_prediction_result, ensure_ascii=False, indent=4))
            
            try:
                prediction_data = latest_prediction_result['predictions'][0]
                full_prediction_string = prediction_data['prediction']
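                # SpeciesNet prediction labels are semicolon-separated taxonomy
                # strings; the last field is the common name.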
                animal_name = full_prediction_string.split(';')[-1]
                score = prediction_data['prediction_score']
                display_text = f"{animal_name.title()}: {score:.2%}"

                # Ignore non-animal results ('blank', 'vehicle', 'no cv result')
                # and any prediction below a 60% confidence score.
                if animal_name.lower() in ('blank', 'vehicle', 'no cv result') or score < 0.60:
                    if score < 0.60:
                        print("Prediction Dropped (low score): ", display_text)

                else:
                    print("Prediction: ", display_text)
                    if 'classifications' in prediction_data and 'classes' in prediction_data['classifications']:
                        classes_list = prediction_data['classifications']['classes']
                        scores_list = prediction_data['classifications']['scores']
                    
                        print("Classifier Results:")
                        # Iterate over both lists together; 'cls_score' avoids
                        # shadowing the top-1 'score' used for display_text above.
                        for classification_string, cls_score in zip(classes_list, scores_list):
                            common_name = classification_string.split(';')[-1]
                            # Print the name and its corresponding score, formatted as a percentage
                            print(f"- {common_name}: {cls_score:.2%}")
                            
                    print(json.dumps(latest_prediction_result, ensure_ascii=False, indent=4))

                    save_image = 1
                    box = None
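                    # Detections use normalized [xmin, ymin, width, height] boxes
                    # (MegaDetector-style), scaled here to pixel coordinates.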
                    if prediction_data['detections']:
                        bbox_norm = prediction_data['detections'][0]['bbox']
                        h, w, _ = frame.shape
                        xmin, ymin = int(bbox_norm[0] * w), int(bbox_norm[1] * h)
                        box_w, box_h = int(bbox_norm[2] * w), int(bbox_norm[3] * h)
                        start_point, end_point = (xmin, ymin), (xmin + box_w, ymin + box_h)
                        box = (start_point, end_point)

                    display_info = {'text': display_text, 'box': box}
                # --------------------------------------------------------

            except (KeyError, IndexError) as e:
                print(f"Could not parse prediction JSON: {e}")
                display_info = None

        if display_info:
            if display_info['box']:
                start_p, end_p = display_info['box']
                cv2.rectangle(frame, start_p, end_p, (0, 255, 0), 2)
                label_y_pos = start_p[1] - 10 if start_p[1] > 20 else start_p[1] + 20
                cv2.putText(frame, display_info['text'], (start_p[0], label_y_pos), 
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                if save_image:
                    h, w, _ = frame.shape
                    timestamp_text = datetime.datetime.now().strftime("%A %d %B %Y %I:%M:%S%p")
                    cv2.rectangle(frame, (0,0), (w,70), (255,255,255), -1)
                    cv2.putText(frame, "Detected: " + timestamp_text, (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5)
                    new_width, new_height = int(w * display_scale_factor), int(h * display_scale_factor)
                    frame_to_display = cv2.resize(frame, (new_width, new_height))


                    full_image_path = generate_temp_filepath(IMAGE_DIRECTORY)

                    cv2.imwrite(full_image_path, frame)
                    filename_for_url = os.path.basename(full_image_path)

                    publish_detection(
                        status=STATUS_ON, 
                        motion_type=f"{animal_name.title()}", 
                        image_filename=filename_for_url
                    )

                    save_image = 0
                    print("Saved image: ", full_image_path)
            #else:
            #    cv2.putText(frame, display_info['text'], (10, 60), 
            #                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                


        h, w, _ = real_time_frame.shape
        timestamp_text = datetime.datetime.now().strftime("%A %d %B %Y %I:%M:%S%p")
        cv2.rectangle(real_time_frame, (0,0), (w,130), (255,255,255), -1)
        cv2.putText(real_time_frame, "Captured:    " + timestamp_text, (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5)
        cv2.putText(real_time_frame, "Last Motion: " + background_frame_timestamp, (20, 120), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5)
        new_width, new_height = int(w * display_scale_factor), int(h * display_scale_factor)
        real_time_frame_to_display = cv2.resize(real_time_frame, (new_width, new_height))


        combined_image = np.hstack((real_time_frame_to_display, frame_to_display))
        #combined_image2 = np.hstack((real_time_frame_to_display, background_frame_to_display))
        #combined_image = np.vstack((combined_image1, combined_image2))

        # Display the single combined image in one window
        cv2.imshow('Combined Window', combined_image)

        #cv2.imshow('Real-time Feed', real_time_frame_to_display)  


        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    print("Exiting...")
    frame_for_prediction_q.put(None) # Signal predictor thread to stop
    pred_thread.join()
    video_stream.stop()
    cv2.destroyAllWindows()

if __name__ == "__main__":
    main()
