import os, math, csv
from collections import defaultdict
import tkinter as tk
from tkinter import scrolledtext, messagebox, IntVar, Frame, Label
from PIL import Image, ImageTk
import googlemaps

""" 
Potential improvements:
- Use a database to store the batch information (e.g. SQLite)
- Use JSON instead of CSV for storing batch information
- Automatically update Google Sheets / Excel Online with the category and notes
- Use a more sophisticated clustering algorithm (e.g. DBSCAN) to group images together
- Use a more sophisticated GUI library (e.g. PyQt) for a more polished interface
- Use a library for calculating distances between coordinates (e.g. geopy)
"""

def get_iemit(filename):
    """Return the iemit number embedded in an image filename.

    The filename is split on underscores and the second field is returned
    unchanged (as a string), e.g. ``"980"`` for
    ``"iemit_980_lat_38.85_lon_54.23_DT_20230620084414.png"``.

    Raises:
        IndexError: If the filename contains no underscore.
    """
    return filename.split("_")[1]

def get_coordinates(filename):
    """Parse the ``(latitude, longitude)`` pair out of an image filename.

    The filename is split on underscores; the fourth field is read as the
    latitude and the sixth as the longitude, e.g. ``(38.5, 54.25)`` for
    ``"iemit_980_lat_38.5_lon_54.25_DT_20230620084414.png"``.

    Raises:
        IndexError: If the filename has fewer underscore-separated fields.
        ValueError: If either field is not a valid float.
    """
    fields = filename.split("_")
    return float(fields[3]), float(fields[5])

def calculate_distance_hav(lat1, lon1, lat2, lon2, radius=6371):
    """Return the great-circle (haversine) distance between two points.

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.
        radius: Radius of the sphere; the result is expressed in the same
            unit. Defaults to 6371 (Earth's radius in kilometers); use 3956
            for miles.

    Returns:
        The distance along the surface of the sphere, as a float.
    """
    # Convert latitude and longitude from degrees to radians
    phi1, lam1, phi2, lam2 = map(math.radians, (lat1, lon1, lat2, lon2))

    # Haversine formula
    dlat = phi2 - phi1
    dlon = lam2 - lam1
    a = math.sin(dlat / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlon / 2) ** 2
    # Floating-point rounding can leave `a` a hair outside [0, 1] for
    # (near-)antipodal points, which would make sqrt/asin raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * math.asin(math.sqrt(a))
    return c * radius

def calculate_distance(lat1, lon1, lat2, lon2):
    """Return the Euclidean distance between two coordinate pairs.

    The coordinates are treated as plain Cartesian values, so the result is
    in the same unit as the inputs (degrees when given latitude/longitude),
    not a distance on the Earth's surface.
    """
    delta_lat = lat1 - lat2
    delta_lon = lon1 - lon2
    squared_sum = delta_lat**2 + delta_lon**2
    return math.sqrt(squared_sum)

# Potential improvement: generate new filenames for each batch
# Potential improvement: DBSCAN clustering
def batch_images(image_folder, tolerance=0.1):
    """Batch the PNG images of a folder by proximity of their coordinates.

    Each image is added to the first existing batch whose center is closer
    than ``tolerance`` (as measured by ``calculate_distance``); if there is
    none, the image starts a new batch and its own coordinates become that
    batch's center.

    Filenames are processed in sorted order. ``os.listdir`` returns entries
    in arbitrary order and this greedy grouping depends on processing order,
    so sorting is what keeps the batch centers identical from one run to the
    next for the same folder contents.

    Args:
        image_folder: Directory to scan; only names ending in ".png" are used.
        tolerance: Maximum distance between an image and a batch center for
            the image to join that batch.

    Returns:
        A ``defaultdict(list)`` mapping each center ``(latitude, longitude)``
        to the list of filenames in that batch.

    Raises:
        OSError: If ``image_folder`` cannot be listed.
        IndexError, ValueError: Propagated from ``get_coordinates`` for a
            ".png" name that does not follow the expected pattern.
    """
    batches = defaultdict(list)

    for filename in sorted(os.listdir(image_folder)):
        if not filename.endswith(".png"):
            continue

        latitude, longitude = get_coordinates(filename)

        # Dict keys iterate in insertion order, so centers are tried in the
        # order in which their batches were created.
        for center in batches:
            if calculate_distance(latitude, longitude, *center) < tolerance:
                batches[center].append(filename)
                break
        else:
            # No center is close enough: this image becomes a new center.
            batches[(latitude, longitude)].append(filename)

    return batches

def find_batch(filename, batches):
    """Locate the batch that contains ``filename``.

    Args:
        filename: Image filename to look for.
        batches: Mapping of batch center to list of filenames.

    Returns:
        ``(center, filenames)`` for the first batch whose list includes
        ``filename``, or ``None`` when no batch contains it.
    """
    matches = (
        (center, members)
        for center, members in batches.items()
        if filename in members
    )
    return next(matches, None)

def read_processed_batches(csv_file_path):
    """Return the set of batch IDs already recorded in the labels CSV.

    The file may or may not start with a header row. If the first non-empty
    row contains a ``batch_id`` cell, it is treated as the header and that
    column is read. Otherwise the file is treated as headerless and the
    first column of every row (including the first) is taken as the batch ID.

    Args:
        csv_file_path: Path of the CSV file; it does not have to exist.

    Returns:
        A set of batch-ID strings; empty when the file is missing or empty.
    """
    processed_batches = set()
    if not os.path.exists(csv_file_path):
        return processed_batches

    with open(csv_file_path, mode='r', newline='', encoding='utf-8-sig') as csvfile:
        id_column = None  # decided by the first non-empty row
        for row in csv.reader(csvfile):
            if not row:
                continue  # blank line
            if id_column is None:
                if 'batch_id' in row:
                    id_column = row.index('batch_id')
                    continue  # header row carries no data
                id_column = 0
            if len(row) > id_column:
                processed_batches.add(row[id_column])
    return processed_batches
    
# The Google Maps API key is taken from the GOOGLE_MAPS_API_KEY environment
# variable when it is set and non-empty.
# NOTE(review): the literal fallback below is a credential committed to source
# control. It should be rotated, and the fallback removed once every
# environment provides the variable.
gmaps = googlemaps.Client(
    key=os.environ.get('GOOGLE_MAPS_API_KEY') or 'AIzaSyBnMCEpxGUxTpGKx0qAXXEbbBpUfhtJSPQ'
)
# Folder that is scanned for ".png" images.
image_folder = "../emit_images"
# Maximum coordinate distance between an image and its batch center.
tolerance = 0.15
# Computed at import time: batch center -> list of image filenames.
batches = batch_images(image_folder, tolerance)
# CSV file that labels are appended to and processed batches are read from.
csv_file_name = "EMIT_labelled.csv"
class BatchDisplayGUI:
    """Tkinter window for labelling image batches, one batch at a time.

    For the current batch the window shows the list of images and nearby
    places (left), the current image (middle) and the label inputs (right).
    "Next Batch" appends one CSV row per image of the batch and moves on.

    Besides its constructor arguments the class uses these module-level
    names: ``batches`` (center -> filenames), ``image_folder``, ``gmaps`` and
    ``csv_file_name``.
    """

    # Written as the first row of a new/empty CSV. 'batch_id' is the column
    # name that read_processed_batches looks up.
    CSV_HEADER = ["batch_id", "filename", "category", "notes", "emitter_present"]
    CATEGORIES = ("upstream", "midstream", "downstream", "landfill", "mine", "others", "?")
    DEFAULT_CATEGORY = "?"

    def __init__(self, master, batches):
        """Open the CSV for appending, build the widgets and show batch 0.

        Args:
            master: Tk root (or toplevel) that hosts the widgets.
            batches: Sequence of batch centers still to be labelled. Each
                center must be a key of the module-level ``batches`` mapping.
        """
        self.csv_file = open(csv_file_name, 'a', newline='', encoding='utf-8')
        self.csv_writer = csv.writer(self.csv_file)
        # Only a new or empty file gets a header; an existing file is left
        # exactly as it is.
        if os.path.getsize(csv_file_name) == 0:
            self.csv_writer.writerow(self.CSV_HEADER)
            self.csv_file.flush()

        self.master = master
        self.batches = batches
        self.current_batch_index = 0
        self.current_image_index = 0
        # filename -> [notes, "yes"/"no"] for every image the user has visited
        self.image_labels = {}
        self.is_closing = False

        print(f"Number of batches left: {len(batches)}")
        self.setup_widgets()
        self.master.protocol("WM_DELETE_WINDOW", self.on_window_close)
        self.display_batch_info()

    def setup_widgets(self):
        """Create and lay out all widgets on the master window."""
        # Create a frame for the navigation buttons
        nav_frame = tk.Frame(self.master)
        nav_frame.grid(row=4, column=1, sticky="ew", padx=5)

        # Create a frame for the inputs on the right side
        input_frame = tk.Frame(self.master)
        input_frame.grid(row=0, column=2, rowspan=4, sticky="nsew", padx=10, pady=10)

        # Configure the grid layout
        for column in range(3):
            self.master.grid_columnconfigure(column, weight=1)

        # Text area on the left side (read-only for the user)
        self.text_area = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=20, height=15)
        self.text_area.grid(row=0, column=0, rowspan=4, padx=10, pady=10, sticky="nsew")
        self.text_area.config(state=tk.DISABLED)

        # Frame for image in the middle
        self.image_frame = Frame(self.master, width=200, height=200)
        self.image_frame.grid(row=0, column=1, rowspan=4, padx=10, pady=10, sticky="nsew")

        # Inputs on the right side
        tk.Label(input_frame, text="Category:").grid(row=0, column=0, sticky="e", padx=5)
        self.category_entry = tk.StringVar()  # holds the selected category
        self.category_entry.set(self.DEFAULT_CATEGORY)
        self.category_menu = tk.OptionMenu(input_frame, self.category_entry, *self.CATEGORIES)
        self.category_menu.grid(row=0, column=1, sticky="ew", padx=5)

        tk.Label(input_frame, text="Notes:").grid(row=1, column=0, sticky="e", padx=5)
        self.notes_entry = tk.Entry(input_frame)
        self.notes_entry.grid(row=1, column=1, sticky="ew", padx=5)

        self.emitter_var = tk.IntVar()
        tk.Checkbutton(input_frame, text="Emitter Present", variable=self.emitter_var).grid(
            row=2, column=0, columnspan=2, sticky="w", padx=5
        )

        # Navigation buttons
        self.prev_button = tk.Button(nav_frame, text="Previous Image", command=self.show_prev_image)
        self.prev_button.grid(row=0, column=0, sticky="ew", padx=5)

        self.next_image_button = tk.Button(nav_frame, text="Next Image", command=self.show_next_image)
        self.next_image_button.grid(row=0, column=1, sticky="ew", padx=5)

        self.next_batch_button = tk.Button(self.master, text="Next Batch", command=self.update_csv)
        self.next_batch_button.grid(row=5, column=1, sticky="ew", padx=5, pady=5)

        # Let the text area and image frame expand to fill the space vertically
        for row in range(4):
            self.master.grid_rowconfigure(row, weight=1)
        nav_frame.columnconfigure(0, weight=1)
        nav_frame.columnconfigure(1, weight=1)
        input_frame.grid_columnconfigure(1, weight=1)

    def _has_current_batch(self):
        """Return True while ``current_batch_index`` points at a batch."""
        return self.current_batch_index < len(self.batches)

    def _current_filenames(self):
        """Return the filenames of the current batch.

        Note that ``batches`` here is the module-level mapping, whereas
        ``self.batches`` is the list of centers passed to the constructor.
        """
        return batches[self.batches[self.current_batch_index]]

    def _set_text(self, text):
        """Replace the content of the read-only text area with ``text``."""
        # A DISABLED text widget ignores insert/delete, so unlock it first.
        self.text_area.config(state=tk.NORMAL)
        self.text_area.delete("1.0", tk.END)
        self.text_area.insert(tk.INSERT, text)
        self.text_area.config(state=tk.DISABLED)

    def display_image(self, filename):
        """Show ``filename`` (from ``image_folder``) in the image frame."""
        # Clear previous images
        for widget in self.image_frame.winfo_children():
            widget.destroy()

        print(f"Displaying image: {get_iemit(filename)}")
        # The context manager releases the file handle as soon as the resized
        # copy has been produced.
        with Image.open(os.path.join(image_folder, filename)) as image:
            resized = image.resize((750, 550), Image.Resampling.LANCZOS)
        photo = ImageTk.PhotoImage(resized)
        label = Label(self.image_frame, image=photo)
        label.image = photo  # keep a reference so the image is not garbage-collected
        label.pack()

    def _load_label(self, filename):
        """Fill the notes/emitter inputs with what was saved for ``filename``."""
        notes, emitter_present = self.image_labels.get(filename, ["", ""])
        # Always clear first: inserting into a non-empty entry would prepend
        # the saved notes to whatever text is still in the field.
        self.notes_entry.delete(0, tk.END)
        self.notes_entry.insert(0, notes)
        self.emitter_var.set(1 if emitter_present == "yes" else 0)

    def _show_image_at(self, index):
        """Save the current inputs, then show image ``index`` of the batch.

        ``index`` is clamped to the valid range of the current batch. Does
        nothing once every batch has been processed.
        """
        if self.is_closing or not self._has_current_batch():
            return
        self.save_current_label()
        filenames = self._current_filenames()
        self.current_image_index = max(0, min(len(filenames) - 1, index))
        filename = filenames[self.current_image_index]
        self._load_label(filename)
        self.display_image(filename)

    def show_next_image(self):
        """Move to the next image of the batch (stays on the last one)."""
        self._show_image_at(self.current_image_index + 1)

    def show_prev_image(self):
        """Move to the previous image of the batch (stays on the first one)."""
        self._show_image_at(self.current_image_index - 1)

    def save_current_label(self):
        """Remember the notes/emitter inputs for the image on screen."""
        if not self._has_current_batch():
            return
        current_image = self._current_filenames()[self.current_image_index]
        notes = self.notes_entry.get()
        emitter_present = "yes" if self.emitter_var.get() else "no"
        self.image_labels[current_image] = [notes, emitter_present]

    def update_csv(self):
        """Write the current batch to the CSV and advance to the next batch.

        One row is written per image: center, filename, the batch category,
        and the image's notes and emitter flag (both empty for images that
        were never displayed).
        """
        if self.is_closing or not self._has_current_batch():
            return

        # The inputs of the image on screen are only stored on navigation, so
        # capture them here or they would be missing from the CSV.
        self.save_current_label()

        center = self.batches[self.current_batch_index]
        category = self.category_entry.get()

        for filename in self._current_filenames():
            notes, emitter_present = self.image_labels.get(filename, ["", ""])
            self.csv_writer.writerow([center, filename, category, notes, emitter_present])
        # Flush per batch so finished work reaches the file before the
        # program exits for any reason.
        self.csv_file.flush()

        # Clear the entry fields
        self.category_entry.set(self.DEFAULT_CATEGORY)
        self.notes_entry.delete(0, tk.END)
        self.emitter_var.set(0)

        self.current_batch_index += 1
        self.current_image_index = 0

        # Display the next batch
        self.display_batch_info()

    def _nearby_place_lines(self, center):
        """Return text lines describing places within 1000 m of ``center``.

        A failed lookup is reported as a single line instead of raising, so
        labelling can continue without the nearby-places information.
        """
        try:
            places_result = gmaps.places_nearby(
                location=(center[0], center[1]), radius=1000, open_now=False
            )
        except (
            googlemaps.exceptions.ApiError,
            googlemaps.exceptions.TransportError,
            googlemaps.exceptions.Timeout,
        ) as err:
            print(f"Nearby places lookup failed: {err}")
            return [f"\t(lookup failed: {err})\n"]

        lines = []
        for place in places_result.get('results', []):
            types_str = ', '.join(place.get('types', []))
            lines.append(f"\t- {place['name']} (Types: {types_str})\n")
        return lines

    def display_batch_info(self):
        """Show the text summary and first image of the current batch.

        When no batch is left, the text area and a message box say so.
        """
        if self.is_closing:
            return

        print(f"Current batch index: {self.current_batch_index}")

        if not self._has_current_batch():
            self._set_text("No more batches to display.")
            messagebox.showinfo("No More Batches", "No more batches to display.")
            return

        center = self.batches[self.current_batch_index]
        filenames = self._current_filenames()

        lines = [f"Batch centered at {center}:\n"]
        lines.extend(f"\t- {get_iemit(filename)}\n" for filename in filenames)
        lines.append("\nNearby places:\n")
        lines.extend(self._nearby_place_lines(center))
        self._set_text("".join(lines))

        # Start by displaying the first image in the batch
        self.display_image(filenames[0])

    def on_window_close(self):
        """Close the CSV file and destroy the window."""
        print("Window is closing")
        self.is_closing = True
        self.csv_file.close()
        self.master.destroy()

def main():
    """Open the labelling window for every batch that is not yet in the CSV.

    A batch counts as processed when the ``str()`` form of its center is
    among the IDs returned by ``read_processed_batches``.
    """
    root = tk.Tk()
    root.title("Batch Information")

    all_centers = set(batches.keys())
    done_ids = read_processed_batches(csv_file_name)

    # Keep only the centers that have not been labelled yet.
    pending = []
    for center in all_centers:
        if str(center) in done_ids:
            continue
        pending.append(center)

    gui = BatchDisplayGUI(root, pending)
    root.mainloop()


if __name__ == "__main__":
    main()

# Print information about a specific batch
# image_folder = "../small"
# tolerance = 0.15
# batches = batch_images(image_folder, tolerance)
# center, filenames = find_batch("iemit_980_lat_38.8503652059909_lon_54.23517605482166_DT_20230620084414.png", batches)
# print(f"Batch centered at {center}:")
# sorted_filenames = sorted(filenames, key=lambda x: int(get_iemit(x)))
# for filename in sorted_filenames:
#     print(f"\t- {get_iemit(filename)}")

# Print information about the batches
# for center, filenames in batches.items():
#   print(f"Batch centered at ({center[0]}, {center[1]}):")
#   for filename in filenames:
#     print(f"\t- {filename}")

#print(f"Number of batches: {len(batches)}")

# Export batches to a CSV file
# with open('batches.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     writer.writerow(['Center Latitude', 'Center Longitude', 'Filename'])
#     for center, filenames in batches.items():
#         for filename in filenames:
#             writer.writerow([center[0], center[1], filename])