import os
import glob
import numpy as np
import cv2
import string
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Flatten, Dense, Dropout
from sklearn.model_selection import train_test_split

DATASET_PATH = "my_dataset"  # directory of labeled images, e.g. "abcde.png"
IMG_WIDTH = 200
IMG_HEIGHT = 50
NUM_CHARS = 5  # every captcha label must be exactly this long
CHAR_SET = string.ascii_lowercase + string.digits  # characters allowed in a label

MODEL_SAVE_PATH = "captcha_model.h5"
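# Lookup tables for converting between characters and class indices.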
CHAR_TO_INT = {char: i for i, char in enumerate(CHAR_SET)}
INT_TO_CHAR = {i: char for i, char in enumerate(CHAR_SET)}
NUM_CLASSES = len(CHAR_SET)

def load_data():
    """Loads images and labels from your dataset directory."""
    print(f"Loading data from: {DATASET_PATH}")
    image_files = glob.glob(os.path.join(DATASET_PATH, "*.png"))
    image_files += glob.glob(os.path.join(DATASET_PATH, "*.jpg"))
    
    images = []
    labels = []
    
    for file in image_files:
        filename = os.path.basename(file)
        label = os.path.splitext(filename)[0]  # the label is the filename without its extension
        
        if len(label) != NUM_CHARS:
            print(f"Skipping {filename}, label length is {len(label)}, expected {NUM_CHARS}.")
            continue

        image = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        
        if image is None:
            print(f"Skipping {filename}, could not read image.")
            continue
            
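        # Resize to a fixed size, scale pixels to [0, 1], and add a channel axis.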
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        image = image.astype("float32") / 255.0
        
        image = np.expand_dims(image, axis=-1)
        
        images.append(image)
        labels.append(label)
        
    if not images:
        print(f"Error: No images found in '{DATASET_PATH}'.")
        print("Please check your DATASET_PATH setting.")
        raise ValueError("No images found!")
        
    print(f"Loaded {len(images)} images.")
    return np.array(images), labels

def preprocess_data(X, y_labels):
    """Filters data and prepares labels for training."""
    print("Preprocessing labels...")
    
    valid_indices = []
    y_processed_lists = {f'output_{i+1}': [] for i in range(NUM_CHARS)}
    
    for idx, label in enumerate(y_labels):
        is_valid_label = True
        char_lists = []
        for char in label:
            if char not in CHAR_TO_INT:
                print(f"Skipping label '{label}': character '{char}' is not in your CHAR_SET.")
                is_valid_label = False
                break
            
            char_int = CHAR_TO_INT[char]
            one_hot = tf.keras.utils.to_categorical(char_int, num_classes=NUM_CLASSES)
            char_lists.append(one_hot)

        if is_valid_label:
            valid_indices.append(idx)
            for i in range(NUM_CHARS):
                y_processed_lists[f'output_{i+1}'].append(char_lists[i])

    X_filtered = X[valid_indices]
    
    y_processed = {key: np.array(val) for key, val in y_processed_lists.items()}
    
    if len(X_filtered) == 0:
        print("\n--- FATAL ERROR ---")
        print("No valid data to train on.")
        print("This probably means your CHAR_SET is wrong or your filenames don't match NUM_CHARS.")
        print(f"Example: Your NUM_CHARS is {NUM_CHARS}, but your filenames might be 'abc.png' (length 3).")
        print(f"Example: Your CHAR_SET is '{CHAR_SET}', but your filenames have characters not in this set.")
        print("Please fix your settings at the top of the file and try again.")
        return None, None

    return X_filtered, y_processed

def build_model():
    """Builds the multi-output CNN model."""
    print("Building model...")
    
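    # Shared convolutional backbone: three conv/pool blocks extract features from the whole image.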
    input_layer = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1), name="input_image")
    
    x = Conv2D(32, (3, 3), activation='relu')(input_layer)
    x = MaxPool2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPool2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPool2D((2, 2))(x)
    
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    
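    # One softmax classification head per character position.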
    output_heads = []
    for i in range(NUM_CHARS):
        head_name = f'output_{i+1}'
        head = Dense(NUM_CLASSES, activation='softmax', name=head_name)(x)
        output_heads.append(head)
        
    model = Model(inputs=input_layer, outputs=output_heads)
    
    loss = {f'output_{i+1}': 'categorical_crossentropy' for i in range(NUM_CHARS)}
    
    metrics = {f'output_{i+1}': 'accuracy' for i in range(NUM_CHARS)}
    
    model.compile(optimizer='adam', loss=loss, metrics=metrics)

    model.summary()  # summary() prints directly; wrapping it in print() would also emit "None"
    return model
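
# A minimal decoding helper (an illustrative addition, not part of the training
# flow): it shows how INT_TO_CHAR turns the model's per-head softmax outputs
# back into text. `preds` is assumed to be the list of NUM_CHARS arrays
# returned by model.predict(), each of shape (batch_size, NUM_CLASSES).
def decode_predictions(preds):
    texts = []
    batch_size = preds[0].shape[0]
    for sample in range(batch_size):
        # For each output head, take the most likely class and map it back to a character.
        chars = [INT_TO_CHAR[int(np.argmax(head[sample]))] for head in preds]
        texts.append("".join(chars))
    return texts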

def main():
    X, y_labels = load_data()
    X_filtered, y_processed = preprocess_data(X, y_labels)
    
    if X_filtered is None:
        return

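    # train_test_split returns X_train and X_val first, then a (train, val)
    # pair for each output head, in the order of y_processed's keys.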
    X_train, X_val, *y_splits = train_test_split(
        X_filtered, *y_processed.values(), test_size=0.2, random_state=42
    )

    keys = list(y_processed.keys())
    y_train = {}
    y_val = {}
    
    for i in range(len(keys)):
        key = keys[i]
        y_train[key] = y_splits[i * 2]  
        y_val[key] = y_splits[i * 2 + 1]

    print(f"Total valid images: {len(X_filtered)}")
    print(f"Training data shape: {X_train.shape}")
    print(f"Validation data shape: {X_val.shape}")
    
    model = build_model()
    
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_loss', save_best_only=True)
    ]
    
    print("Starting training...")
    history = model.fit(
        X_train,
        y_train,
        validation_data=(X_val, y_val),
        epochs=30,  
        batch_size=32,
        callbacks=callbacks
    )
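
    # Sanity check (illustrative; assumes the decode_predictions sketch above
    # is kept): decode a few validation predictions back into text.
    sample_preds = model.predict(X_val[:5])
    print("Sample predictions:", decode_predictions(sample_preds))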
    
    print("\nTraining complete!")
    print(f"Best model saved to {MODEL_SAVE_PATH}")

if __name__ == "__main__":
    main()
