File size: 6,149 Bytes
bb9059d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
"""
ResidualConvAutoencoder - Deepfake Detection Model
Architecture: 5-stage encoder-decoder with residual blocks
"""
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit built from two 3x3 convolutions.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Both convolutions
    use padding 1 and keep the channel count, so the output has exactly the
    same shape as the input.

    Args:
        channels: Number of feature channels going in and coming out.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names double as state_dict keys (conv1.*, bn1.*, ...).
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the residual unit on ``x`` and return a tensor of the same shape."""
        shortcut = x

        # First conv -> norm -> activation.
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)

        # Second conv -> norm; the activation is applied after the skip add.
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        # In-place add of the skip connection (writes into `hidden`, not `x`).
        hidden.add_(shortcut)
        return self.relu(hidden)
class ResidualConvAutoencoder(nn.Module):
    """
    Residual Convolutional Autoencoder for image reconstruction and deepfake detection.

    The encoder applies five stride-2 convolutions (128 -> 4 spatially), the
    first four each followed by a ResidualBlock. A pair of linear layers maps
    the flattened 512x4x4 feature map to and from the latent vector, and the
    decoder mirrors the encoder with five stride-2 transposed convolutions,
    ending in ``Tanh``.

    Args:
        latent_dim (int): Dimension of latent space (default: 512)

    Input:
        x: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]

    Output:
        reconstructed: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]
        latent: Tensor of shape (batch_size, latent_dim)
    """

    # Shape of the feature map on either side of the linear bottleneck.
    _BOTTLENECK_CHANNELS = 512
    _BOTTLENECK_SIZE = 4

    def __init__(self, latent_dim=512):
        super().__init__()
        self.latent_dim = latent_dim

        # NOTE: the position of every layer inside the two nn.Sequential
        # containers determines its state_dict key (encoder.0.*, decoder.3.*,
        # ...), so layers must not be inserted or reordered if existing
        # checkpoints are to keep loading.

        # Encoder: 128x128 -> 4x4
        self.encoder = nn.Sequential(
            # Stage 1: 128 -> 64
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64),
            # Stage 2: 64 -> 32
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128),
            # Stage 3: 32 -> 16
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256),
            # Stage 4: 16 -> 8
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512),
            # Stage 5: 8 -> 4 (no residual block at the bottleneck)
            nn.Conv2d(512, self._BOTTLENECK_CHANNELS, 4, stride=2, padding=1),
            nn.BatchNorm2d(self._BOTTLENECK_CHANNELS),
            nn.ReLU(inplace=True),
        )

        # Bottleneck: flattened feature map <-> latent vector
        flat_features = (
            self._BOTTLENECK_CHANNELS * self._BOTTLENECK_SIZE * self._BOTTLENECK_SIZE
        )
        self.fc_encoder = nn.Linear(flat_features, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, flat_features)

        # Decoder: 4x4 -> 128x128
        self.decoder = nn.Sequential(
            # Stage 1: 4 -> 8
            nn.ConvTranspose2d(self._BOTTLENECK_CHANNELS, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512),
            # Stage 2: 8 -> 16
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256),
            # Stage 3: 16 -> 32
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128),
            # Stage 4: 32 -> 64
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64),
            # Stage 5: 64 -> 128
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),  # Output in [-1, 1]
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            reconstructed: Reconstructed image of shape (batch_size, 3, 128, 128)
            latent: Latent representation of shape (batch_size, latent_dim)
        """
        latent = self.encode(x)
        reconstructed = self.decode(latent)
        return reconstructed, latent

    def encode(self, x):
        """Extract latent representation only.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            Latent tensor of shape (batch_size, latent_dim)
        """
        features = self.encoder(x)
        # flatten() falls back to a copy when the activations are not
        # contiguous (e.g. channels_last inputs), where view() would raise.
        return self.fc_encoder(torch.flatten(features, 1))

    def decode(self, latent):
        """Reconstruct from latent representation.

        Args:
            latent: Latent tensor of shape (batch_size, latent_dim)

        Returns:
            Reconstructed image of shape (batch_size, 3, 128, 128)
        """
        x = self.fc_decoder(latent)
        x = x.reshape(
            x.size(0),
            self._BOTTLENECK_CHANNELS,
            self._BOTTLENECK_SIZE,
            self._BOTTLENECK_SIZE,
        )
        return self.decoder(x)

    def reconstruction_error(self, x, reduction='mean'):
        """
        Calculate reconstruction error (MSE) between ``x`` and its reconstruction.
        Useful for anomaly/deepfake detection.

        Gradient tracking is not disabled here; wrap the call in
        ``torch.no_grad()`` when only scoring samples.

        Args:
            x: Input tensor
            reduction: 'mean' for a single scalar averaged over the whole
                batch, 'none' for one error value per sample

        Returns:
            Reconstruction error (MSE): a 0-dim tensor for 'mean', a tensor
            of shape (batch_size,) for 'none'

        Raises:
            ValueError: If ``reduction`` is neither 'mean' nor 'none'
        """
        # Validate first so a typo does not cost a full forward pass.
        if reduction not in ('mean', 'none'):
            raise ValueError(f"Unknown reduction: {reduction}")

        # Call the module (not .forward) so registered hooks are honoured.
        reconstructed, _ = self(x)
        error = (reconstructed - x) ** 2

        if reduction == 'mean':
            return error.mean()
        # flatten() instead of view(): `error` may be non-contiguous when
        # `x` uses a non-default memory format.
        return error.flatten(1).mean(dim=1)
def load_model(checkpoint_path, device='cuda', latent_dim=None):
    """
    Load pretrained model from checkpoint.

    The checkpoint may be a bare state dict or a dict wrapping it under the
    key 'model_state_dict' or 'state_dict'. Keys carrying a leading
    ``module.`` prefix (as saved from DataParallel / DistributedDataParallel
    wrappers) are accepted as well.

    Args:
        checkpoint_path: Path to .ckpt file
        device: 'cuda' or 'cpu'
        latent_dim: Latent size of the saved model. When None (default) it is
            read from the shape of the checkpoint's ``fc_encoder.weight``;
            if that entry is missing, 512 is used.

    Returns:
        model: Loaded ResidualConvAutoencoder in eval mode

    Raises:
        TypeError: If the file does not contain a dict-like checkpoint
    """
    from collections.abc import Mapping

    # Available in torch >= 1.9.
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    # NOTE(security): torch.load unpickles the file and, unless weights-only
    # loading is in effect, can execute arbitrary code. Only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if not isinstance(checkpoint, Mapping):
        raise TypeError(
            "Expected a dict-like checkpoint or state dict, "
            f"got {type(checkpoint).__name__}"
        )

    # Unwrap the state dict from the supported container layouts.
    state_dict = checkpoint
    for wrapper_key in ('model_state_dict', 'state_dict'):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break

    # Strip the prefix added by DataParallel/DDP wrappers, if present.
    consume_prefix_in_state_dict_if_present(state_dict, 'module.')

    if latent_dim is None:
        # nn.Linear stores its weight as (out_features, in_features), so the
        # first dimension of fc_encoder.weight is the latent size.
        fc_weight = state_dict.get('fc_encoder.weight')
        latent_dim = 512 if fc_weight is None else fc_weight.shape[0]

    model = ResidualConvAutoencoder(latent_dim=latent_dim)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
|