""" Example: Using the model for deepfake detection """ import torch from torchvision import transforms from PIL import Image from model import load_model import json # Load model device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = load_model('model_best_checkpoint.ckpt', device=device) # Load calibrated thresholds with open('thresholds_calibrated.json', 'r') as f: config = json.load(f) threshold = config['reconstruction_thresholds']['thresholds']['balanced']['value'] print(f"Using threshold: {threshold:.6f}") # Prepare image preprocessing transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) def detect_deepfake(image_path, model, threshold, device): """ Detect if an image is likely a deepfake based on reconstruction error. Args: image_path: Path to image file model: Loaded autoencoder model threshold: MSE threshold for detection device: torch device Returns: is_fake: Boolean indicating if image is likely fake error: Reconstruction error value confidence: Confidence score (0-1) """ # Load and preprocess image image = Image.open(image_path).convert('RGB') input_tensor = transform(image).unsqueeze(0).to(device) # Calculate reconstruction error with torch.no_grad(): error = model.reconstruction_error(input_tensor, reduction='none') error_value = error.item() is_fake = error_value > threshold # Calculate confidence (normalized error relative to threshold) confidence = min(abs(error_value - threshold) / threshold, 1.0) return is_fake, error_value, confidence # Example usage image_path = "test_image.jpg" is_fake, error, confidence = detect_deepfake(image_path, model, threshold, device) print(f"\nResults for: {image_path}") print(f"Reconstruction Error: {error:.6f}") print(f"Threshold: {threshold:.6f}") print(f"Classification: {'FAKE' if is_fake else 'REAL'}") print(f"Confidence: {confidence:.2%}") # Batch processing example def batch_detect(image_paths, model, threshold, device): """Process multiple images at once""" images = [] for path in image_paths: img = Image.open(path).convert('RGB') images.append(transform(img)) batch = torch.stack(images).to(device) with torch.no_grad(): errors = model.reconstruction_error(batch, reduction='none') results = [] for i, error in enumerate(errors): is_fake = error.item() > threshold results.append({ 'path': image_paths[i], 'error': error.item(), 'is_fake': is_fake }) return results # Example batch processing # image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"] # results = batch_detect(image_paths, model, threshold, device) # for r in results: # print(f"{r['path']}: {'FAKE' if r['is_fake'] else 'REAL'} (error: {r['error']:.6f})")