import os
import sys  
import numpy as np
import cv2
import string
import tensorflow as tf

# Target size that preprocess_image() resizes every input image to.
IMG_WIDTH, IMG_HEIGHT = 200, 50
# Presumably the number of characters per CAPTCHA (not used in this script).
NUM_CHARS = 5
# Recognisable characters: lowercase ASCII letters followed by digits.
CHAR_SET = "".join((string.ascii_lowercase, string.digits))

# Saved model file; a relative path is resolved against the current working dir.
MODEL_PATH = "captcha_model.h5"
# Class index -> character, in CHAR_SET order.
INT_TO_CHAR = dict(enumerate(CHAR_SET))
NUM_CLASSES = len(INT_TO_CHAR)

def load_prediction_model(model_path=None):
    """Load the trained Keras model used for CAPTCHA prediction.

    Args:
        model_path: Path to the saved model. When omitted (``None``), the
            module-level ``MODEL_PATH`` is used, so existing zero-argument
            callers behave exactly as before.

    Returns:
        The model loaded with ``compile=False`` (no optimizer/loss attached,
        which is enough for ``predict``), or ``None`` after printing an error
        if nothing exists at ``model_path``.
    """
    if model_path is None:
        model_path = MODEL_PATH

    # os.path.exists (not isfile) so a directory-based saved model also passes.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        print("Please run train_model.py first.")
        return None

    print(f"Loading model from {model_path}...")
    # NOTE(review): TensorFlow typically reads TF_CPP_MIN_LOG_LEVEL when it is
    # first imported; tensorflow is imported at module load, so setting it here
    # likely has little effect. Consider setting it before `import tensorflow`.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")
    return model


def preprocess_image(image_path):
    """Read one image from disk and shape it into a single-sample model batch.

    The file is read as grayscale, resized to IMG_WIDTH x IMG_HEIGHT, cast to
    float32 and divided by 255, then given a leading batch axis and a
    trailing channel axis: shape (1, IMG_HEIGHT, IMG_WIDTH, 1).

    Returns None (after printing an error) when OpenCV cannot read the file.
    """
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        print(f"Error: Could not read image at {image_path}")
        return None

    # cv2.resize expects the target size as (width, height).
    resized = cv2.resize(gray, (IMG_WIDTH, IMG_HEIGHT))
    scaled = resized.astype("float32") / 255.0
    # Batch axis in front, single channel axis at the end.
    return scaled[np.newaxis, :, :, np.newaxis]

def decode_prediction(prediction):
    """Turn the model's per-character softmax outputs into a text string.

    ``prediction`` is iterated once per character position. For each
    position, the arg-max along axis 1 of its first row (batch index 0) is the
    class index, which is mapped through INT_TO_CHAR. Unknown indices become "?".

    NOTE(review): assumes a multi-output model whose ``predict`` result is a
    list with one (batch, NUM_CLASSES) array per character — confirm against
    the training script.
    """
    return "".join(
        INT_TO_CHAR.get(np.argmax(position_scores, axis=1)[0], "?")
        for position_scores in prediction
    )

def main():
    """Command-line entry point: predict the text of a single CAPTCHA image.

    Usage: ``<script> <image_path>``. The expected label is the image's file
    name without its extension. It is compared with the model's prediction
    and the result is printed.

    Returns:
        Process exit status: 0 when a prediction was produced (whether or not
        it matched the file name), 1 on a usage, file, model or image error.
    """
    if len(sys.argv) < 2:
        print("\nError: No image file provided.")
        print("Please specify which image you want to predict.")
        print("\n--- HOW TO USE ---")
        print(f"Example: py {sys.argv[0]} my_dataset\\2b827.png")
        print("   (You can also drag and drop your image file onto the terminal)")
        return 1

    image_path_from_user = sys.argv[1]

    if not os.path.exists(image_path_from_user):
        print(f"\nError: File not found at '{image_path_from_user}'")
        return 1

    model = load_prediction_model()
    if model is None:
        return 1

    # splitext removes only the final extension, so a stem containing dots
    # is kept whole instead of being cut at the first '.'.
    true_label = os.path.splitext(os.path.basename(image_path_from_user))[0]
    print("\n--- Testing on Specific Image ---")
    print(f"File: {image_path_from_user}")
    print(f"True Label (from filename): {true_label}")

    image_data = preprocess_image(image_path_from_user)
    if image_data is None:
        return 1

    tf.get_logger().setLevel('ERROR')
    prediction = model.predict(image_data)
    predicted_label = decode_prediction(prediction)

    print(f"Predicted Label: {predicted_label}")
    if true_label == predicted_label:
        print("Result: CORRECT!")
    else:
        print("Result: INCORRECT")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status;
    # sys.exit(None) exits with status 0.
    sys.exit(main())