ash12321 commited on
Commit
bb9059d
·
verified ·
1 Parent(s): 1111f68

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +193 -0
model.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ResidualConvAutoencoder - Deepfake Detection Model
3
+ Architecture: 5-stage encoder-decoder with residual blocks
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ class ResidualBlock(nn.Module):
10
+ """Residual block with two conv layers and skip connection"""
11
+ def __init__(self, channels):
12
+ super().__init__()
13
+ self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
14
+ self.bn1 = nn.BatchNorm2d(channels)
15
+ self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
16
+ self.bn2 = nn.BatchNorm2d(channels)
17
+ self.relu = nn.ReLU(inplace=True)
18
+
19
+ def forward(self, x):
20
+ residual = x
21
+ out = self.relu(self.bn1(self.conv1(x)))
22
+ out = self.bn2(self.conv2(out))
23
+ out += residual
24
+ return self.relu(out)
25
+
26
+ class ResidualConvAutoencoder(nn.Module):
27
+ """
28
+ Residual Convolutional Autoencoder for image reconstruction and deepfake detection.
29
+
30
+ Args:
31
+ latent_dim (int): Dimension of latent space (default: 512)
32
+
33
+ Input:
34
+ x: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]
35
+
36
+ Output:
37
+ reconstructed: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]
38
+ latent: Tensor of shape (batch_size, latent_dim)
39
+ """
40
+ def __init__(self, latent_dim=512):
41
+ super().__init__()
42
+ self.latent_dim = latent_dim
43
+
44
+ # Encoder: 128x128 -> 4x4
45
+ self.encoder = nn.Sequential(
46
+ # Stage 1: 128 -> 64
47
+ nn.Conv2d(3, 64, 4, stride=2, padding=1),
48
+ nn.BatchNorm2d(64),
49
+ nn.ReLU(inplace=True),
50
+ ResidualBlock(64),
51
+
52
+ # Stage 2: 64 -> 32
53
+ nn.Conv2d(64, 128, 4, stride=2, padding=1),
54
+ nn.BatchNorm2d(128),
55
+ nn.ReLU(inplace=True),
56
+ ResidualBlock(128),
57
+
58
+ # Stage 3: 32 -> 16
59
+ nn.Conv2d(128, 256, 4, stride=2, padding=1),
60
+ nn.BatchNorm2d(256),
61
+ nn.ReLU(inplace=True),
62
+ ResidualBlock(256),
63
+
64
+ # Stage 4: 16 -> 8
65
+ nn.Conv2d(256, 512, 4, stride=2, padding=1),
66
+ nn.BatchNorm2d(512),
67
+ nn.ReLU(inplace=True),
68
+ ResidualBlock(512),
69
+
70
+ # Stage 5: 8 -> 4
71
+ nn.Conv2d(512, 512, 4, stride=2, padding=1),
72
+ nn.BatchNorm2d(512),
73
+ nn.ReLU(inplace=True),
74
+ )
75
+
76
+ # Bottleneck
77
+ self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
78
+ self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)
79
+
80
+ # Decoder: 4x4 -> 128x128
81
+ self.decoder = nn.Sequential(
82
+ # Stage 1: 4 -> 8
83
+ nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
84
+ nn.BatchNorm2d(512),
85
+ nn.ReLU(inplace=True),
86
+ ResidualBlock(512),
87
+
88
+ # Stage 2: 8 -> 16
89
+ nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
90
+ nn.BatchNorm2d(256),
91
+ nn.ReLU(inplace=True),
92
+ ResidualBlock(256),
93
+
94
+ # Stage 3: 16 -> 32
95
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
96
+ nn.BatchNorm2d(128),
97
+ nn.ReLU(inplace=True),
98
+ ResidualBlock(128),
99
+
100
+ # Stage 4: 32 -> 64
101
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
102
+ nn.BatchNorm2d(64),
103
+ nn.ReLU(inplace=True),
104
+ ResidualBlock(64),
105
+
106
+ # Stage 5: 64 -> 128
107
+ nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
108
+ nn.Tanh() # Output in [-1, 1]
109
+ )
110
+
111
+ def forward(self, x):
112
+ """
113
+ Forward pass through the autoencoder.
114
+
115
+ Args:
116
+ x: Input tensor of shape (batch_size, 3, 128, 128)
117
+
118
+ Returns:
119
+ reconstructed: Reconstructed image of shape (batch_size, 3, 128, 128)
120
+ latent: Latent representation of shape (batch_size, latent_dim)
121
+ """
122
+ # Encode
123
+ x = self.encoder(x)
124
+ x = x.view(x.size(0), -1)
125
+ latent = self.fc_encoder(x)
126
+
127
+ # Decode
128
+ x = self.fc_decoder(latent)
129
+ x = x.view(x.size(0), 512, 4, 4)
130
+ reconstructed = self.decoder(x)
131
+
132
+ return reconstructed, latent
133
+
134
+ def encode(self, x):
135
+ """Extract latent representation only"""
136
+ x = self.encoder(x)
137
+ x = x.view(x.size(0), -1)
138
+ latent = self.fc_encoder(x)
139
+ return latent
140
+
141
+ def decode(self, latent):
142
+ """Reconstruct from latent representation"""
143
+ x = self.fc_decoder(latent)
144
+ x = x.view(x.size(0), 512, 4, 4)
145
+ reconstructed = self.decoder(x)
146
+ return reconstructed
147
+
148
+ def reconstruction_error(self, x, reduction='mean'):
149
+ """
150
+ Calculate per-sample reconstruction error (MSE).
151
+ Useful for anomaly/deepfake detection.
152
+
153
+ Args:
154
+ x: Input tensor
155
+ reduction: 'mean' for average error, 'none' for per-sample errors
156
+
157
+ Returns:
158
+ Reconstruction error (MSE)
159
+ """
160
+ reconstructed, _ = self.forward(x)
161
+ error = (reconstructed - x) ** 2
162
+
163
+ if reduction == 'mean':
164
+ return error.mean()
165
+ elif reduction == 'none':
166
+ return error.view(x.size(0), -1).mean(dim=1)
167
+ else:
168
+ raise ValueError(f"Unknown reduction: {reduction}")
169
+
170
+ def load_model(checkpoint_path, device='cuda'):
171
+ """
172
+ Load pretrained model from checkpoint.
173
+
174
+ Args:
175
+ checkpoint_path: Path to .ckpt file
176
+ device: 'cuda' or 'cpu'
177
+
178
+ Returns:
179
+ model: Loaded ResidualConvAutoencoder in eval mode
180
+ """
181
+ model = ResidualConvAutoencoder(latent_dim=512)
182
+ checkpoint = torch.load(checkpoint_path, map_location=device)
183
+
184
+ if 'model_state_dict' in checkpoint:
185
+ model.load_state_dict(checkpoint['model_state_dict'])
186
+ elif 'state_dict' in checkpoint:
187
+ model.load_state_dict(checkpoint['state_dict'])
188
+ else:
189
+ model.load_state_dict(checkpoint)
190
+
191
+ model = model.to(device)
192
+ model.eval()
193
+ return model