charlesnchr committed on
Commit 07aa057 · 1 Parent(s): 0bf9103

Added architecture definitions

Files changed (3)
  1. archs/swin3d_rcab.py +881 -0
  2. archs/swinir_rcab.py +1296 -0
  3. requirements.txt +2 -0
archs/swin3d_rcab.py ADDED
@@ -0,0 +1,881 @@
1
+ import math
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint as checkpoint
5
+ import numpy as np
6
+ from timm.models.layers import DropPath, trunc_normal_
7
+
8
+ # from mmcv.runner import load_checkpoint
9
+ # from mmaction.utils import get_root_logger
10
+ # from ..builder import BACKBONES
11
+
12
+ from functools import reduce, lru_cache
13
+ from operator import mul
14
+ from einops import rearrange
15
+ import sys
16
+
17
+ class Upsample(nn.Sequential):
18
+ """Upsample module.
19
+
20
+ Args:
21
+ scale (int): Scale factor. Supported scales: 2^n and 3.
22
+ num_feat (int): Channel number of intermediate features.
23
+ """
24
+
25
+ def __init__(self, scale, num_feat):
26
+ m = []
27
+ if (scale & (scale - 1)) == 0: # scale = 2^n
28
+ for _ in range(int(math.log(scale, 2))):
29
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
30
+ m.append(nn.PixelShuffle(2))
31
+ elif scale == 3:
32
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
33
+ m.append(nn.PixelShuffle(3))
34
+ else:
35
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
36
+ super(Upsample, self).__init__(*m)
37
+
38
+
39
+
40
+ def make_layer(basic_block, num_basic_block, **kwarg):
41
+ """Make layers by stacking the same blocks.
42
+
43
+ Args:
44
+ basic_block (nn.module): nn.module class for basic block.
45
+ num_basic_block (int): number of blocks.
46
+
47
+ Returns:
48
+ nn.Sequential: Stacked blocks in nn.Sequential.
49
+ """
50
+ layers = []
51
+ for _ in range(num_basic_block):
52
+ layers.append(basic_block(**kwarg))
53
+ return nn.Sequential(*layers)
54
+
55
+
56
+ class ChannelAttention(nn.Module):
57
+ """Channel attention used in RCAN.
58
+
59
+ Args:
60
+ num_feat (int): Channel number of intermediate features.
61
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
62
+ """
63
+
64
+ def __init__(self, num_feat, squeeze_factor=16):
65
+ super(ChannelAttention, self).__init__()
66
+ self.attention = nn.Sequential(
67
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
68
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
69
+
70
+ def forward(self, x):
71
+ y = self.attention(x)
72
+ return x * y
73
+
74
+
75
+ class RCAB(nn.Module):
76
+ """Residual Channel Attention Block (RCAB) used in RCAN.
77
+
78
+ Args:
79
+ num_feat (int): Channel number of intermediate features.
80
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
81
+ res_scale (float): Scale the residual. Default: 1.
82
+ """
83
+
84
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
85
+ super(RCAB, self).__init__()
86
+ self.res_scale = res_scale
87
+
88
+ self.rcab = nn.Sequential(
89
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
90
+ ChannelAttention(num_feat, squeeze_factor))
91
+
92
+ def forward(self, x):
93
+ res = self.rcab(x) * self.res_scale
94
+ return res + x
95
+
96
+
97
+ class ResidualGroup(nn.Module):
98
+ """Residual Group of RCAB.
99
+
100
+ Args:
101
+ num_feat (int): Channel number of intermediate features.
102
+ num_block (int): Block number in the body network.
103
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
104
+ res_scale (float): Scale the residual. Default: 1.
105
+ """
106
+
107
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
108
+ super(ResidualGroup, self).__init__()
109
+
110
+ self.residual_group = make_layer(
111
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
112
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
113
+
114
+ def forward(self, x):
115
+ res = self.conv(self.residual_group(x))
116
+ return res + x
117
+
118
+
119
+ class Mlp(nn.Module):
120
+ """ Multilayer perceptron."""
121
+
122
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
123
+ super().__init__()
124
+ out_features = out_features or in_features
125
+ hidden_features = hidden_features or in_features
126
+ self.fc1 = nn.Linear(in_features, hidden_features)
127
+ self.act = act_layer()
128
+ self.fc2 = nn.Linear(hidden_features, out_features)
129
+ self.drop = nn.Dropout(drop)
130
+
131
+ def forward(self, x):
132
+ x = self.fc1(x)
133
+ x = self.act(x)
134
+ x = self.drop(x)
135
+ x = self.fc2(x)
136
+ x = self.drop(x)
137
+ return x
138
+
139
+
140
+ def window_partition(x, window_size):
141
+ """
142
+ Args:
143
+ x: (B, D, H, W, C)
144
+ window_size (tuple[int]): window size
145
+
146
+ Returns:
147
+ windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
148
+ """
149
+ B, D, H, W, C = x.shape
150
+ x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
151
+ windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
152
+ return windows
153
+
154
+
155
+ def window_reverse(windows, window_size, B, D, H, W):
156
+ """
157
+ Args:
158
+ windows: (B*num_windows, window_size[0], window_size[1], window_size[2], C)
159
+ window_size (tuple[int]): Window size
160
+ B (int): Batch size
+ D (int): Depth (number of frames)
+ H (int): Height of image
161
+ W (int): Width of image
162
+
163
+ Returns:
164
+ x: (B, D, H, W, C)
165
+ """
166
+ x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
167
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
168
+ return x
169
+
170
+
171
+
172
+
173
+ def get_window_size(x_size, window_size, shift_size=None):
174
+ use_window_size = list(window_size)
175
+ if shift_size is not None:
176
+ use_shift_size = list(shift_size)
177
+ for i in range(len(x_size)):
178
+ if x_size[i] <= window_size[i]:
179
+ use_window_size[i] = x_size[i]
180
+ if shift_size is not None:
181
+ use_shift_size[i] = 0
182
+
183
+ if shift_size is None:
184
+ return tuple(use_window_size)
185
+ else:
186
+ return tuple(use_window_size), tuple(use_shift_size)
187
+
188
+
189
+ class WindowAttention3D(nn.Module):
190
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
191
+ It supports both of shifted and non-shifted window.
192
+ Args:
193
+ dim (int): Number of input channels.
194
+ window_size (tuple[int]): The temporal length, height and width of the window.
195
+ num_heads (int): Number of attention heads.
196
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
197
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
198
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
199
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
200
+ """
201
+
202
+ def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
203
+
204
+ super().__init__()
205
+ self.dim = dim
206
+ self.window_size = window_size # Wd, Wh, Ww
207
+ self.num_heads = num_heads
208
+ head_dim = dim // num_heads
209
+ self.scale = qk_scale or head_dim ** -0.5
210
+
211
+ # define a parameter table of relative position bias
212
+ self.relative_position_bias_table = nn.Parameter(
213
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
214
+
215
+ # get pair-wise relative position index for each token inside the window
216
+ coords_d = torch.arange(self.window_size[0])
217
+ coords_h = torch.arange(self.window_size[1])
218
+ coords_w = torch.arange(self.window_size[2])
219
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
220
+ coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
221
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
222
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
223
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
224
+ relative_coords[:, :, 1] += self.window_size[1] - 1
225
+ relative_coords[:, :, 2] += self.window_size[2] - 1
226
+
227
+ relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
228
+ relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
229
+ relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
230
+ self.register_buffer("relative_position_index", relative_position_index)
231
+
232
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
233
+ self.attn_drop = nn.Dropout(attn_drop)
234
+ self.proj = nn.Linear(dim, dim)
235
+ self.proj_drop = nn.Dropout(proj_drop)
236
+
237
+ trunc_normal_(self.relative_position_bias_table, std=.02)
238
+ self.softmax = nn.Softmax(dim=-1)
239
+
240
+ def forward(self, x, mask=None):
241
+ """ Forward function.
242
+ Args:
243
+ x: input features with shape of (num_windows*B, N, C)
244
+ mask: (0/-inf) mask with shape of (num_windows, N, N) or None
245
+ """
246
+ B_, N, C = x.shape
247
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
248
+ q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C
249
+
250
+ q = q * self.scale
251
+ attn = q @ k.transpose(-2, -1)
252
+
253
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
254
+ N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH
255
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww
256
+ attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
257
+
258
+ if mask is not None:
259
+ nW = mask.shape[0]
260
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
261
+ attn = attn.view(-1, self.num_heads, N, N)
262
+ attn = self.softmax(attn)
263
+ else:
264
+ attn = self.softmax(attn)
265
+
266
+ attn = self.attn_drop(attn)
267
+
268
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
269
+ x = self.proj(x)
270
+ x = self.proj_drop(x)
271
+ return x
272
+
273
+
274
+ class SwinTransformerBlock3D(nn.Module):
275
+ """ Swin Transformer Block.
276
+
277
+ Args:
278
+ dim (int): Number of input channels.
279
+ num_heads (int): Number of attention heads.
280
+ window_size (tuple[int]): Window size.
281
+ shift_size (tuple[int]): Shift size for SW-MSA.
282
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
283
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
284
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
285
+ drop (float, optional): Dropout rate. Default: 0.0
286
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
287
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
288
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
289
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
290
+ """
291
+
292
+ def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
293
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
294
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
295
+ super().__init__()
296
+ self.dim = dim
297
+ self.num_heads = num_heads
298
+ self.window_size = window_size
299
+ self.shift_size = shift_size
300
+ self.mlp_ratio = mlp_ratio
301
+ self.use_checkpoint=use_checkpoint
302
+
303
+ assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in 0-window_size"
304
+ assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in 0-window_size"
305
+ assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in 0-window_size"
306
+
307
+ self.norm1 = norm_layer(dim)
308
+ self.attn = WindowAttention3D(
309
+ dim, window_size=self.window_size, num_heads=num_heads,
310
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
311
+
312
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
313
+ self.norm2 = norm_layer(dim)
314
+ mlp_hidden_dim = int(dim * mlp_ratio)
315
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
316
+
317
+ def forward_part1(self, x, mask_matrix):
318
+ B, D, H, W, C = x.shape
319
+ window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
320
+
321
+ x = self.norm1(x)
322
+ # pad feature maps to multiples of window size
323
+ pad_l = pad_t = pad_d0 = 0
324
+ pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
325
+ pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
326
+ pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
327
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
328
+ _, Dp, Hp, Wp, _ = x.shape
329
+ # cyclic shift
330
+ if any(i > 0 for i in shift_size):
331
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
332
+ attn_mask = mask_matrix
333
+ else:
334
+ shifted_x = x
335
+ attn_mask = None
336
+ # partition windows
337
+ x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C
338
+ # W-MSA/SW-MSA
339
+ attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C
340
+ # merge windows
341
+ attn_windows = attn_windows.view(-1, *(window_size+(C,)))
342
+ shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C
343
+ # reverse cyclic shift
344
+ if any(i > 0 for i in shift_size):
345
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
346
+ else:
347
+ x = shifted_x
348
+
349
+ if pad_d1 >0 or pad_r > 0 or pad_b > 0:
350
+ x = x[:, :D, :H, :W, :].contiguous()
351
+ return x
352
+
353
+ def forward_part2(self, x):
354
+ return self.drop_path(self.mlp(self.norm2(x)))
355
+
356
+ def forward(self, x, mask_matrix):
357
+ """ Forward function.
358
+
359
+ Args:
360
+ x: Input feature, tensor size (B, D, H, W, C).
361
+ mask_matrix: Attention mask for cyclic shift.
362
+ """
363
+
364
+ shortcut = x
365
+ if self.use_checkpoint:
366
+ x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
367
+ else:
368
+ x = self.forward_part1(x, mask_matrix)
369
+ x = shortcut + self.drop_path(x)
370
+
371
+ if self.use_checkpoint:
372
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
373
+ else:
374
+ x = x + self.forward_part2(x)
375
+
376
+ return x
377
+
378
+
379
+ class PatchMerging(nn.Module):
380
+ """ Patch Merging Layer
381
+
382
+ Args:
383
+ dim (int): Number of input channels.
384
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
385
+ """
386
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
387
+ super().__init__()
388
+ self.dim = dim
389
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
390
+ self.norm = norm_layer(4 * dim)
391
+
392
+ def forward(self, x):
393
+ """ Forward function.
394
+
395
+ Args:
396
+ x: Input feature, tensor size (B, D, H, W, C).
397
+ """
398
+ B, D, H, W, C = x.shape
399
+
400
+ # padding
401
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
402
+ if pad_input:
403
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
404
+
405
+ x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C
406
+ x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C
407
+ x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C
408
+ x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C
409
+ x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C
410
+
411
+ x = self.norm(x)
412
+ x = self.reduction(x)
413
+
414
+ return x
415
+
416
+
417
+ # cache each stage results
418
+ @lru_cache()
419
+ def compute_mask(D, H, W, window_size, shift_size, device):
420
+ img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
421
+ cnt = 0
422
+ for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):
423
+ for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):
424
+ for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):
425
+ img_mask[:, d, h, w, :] = cnt
426
+ cnt += 1
427
+ mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1
428
+ mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]
429
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
430
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
431
+ return attn_mask
432
+
433
+ class RSTB3D(nn.Module):
434
+ """ Residual Swin Transformer block for 3D inputs: a BasicLayer whose output is refined by an RCAB-based convolutional branch (patch unembed -> conv/ResidualGroup -> patch embed) before a residual connection.
435
+
436
+ Args:
437
+ dim (int): Number of feature channels
438
+ depth (int): Depths of this stage.
439
+ num_heads (int): Number of attention head.
440
+ window_size (tuple[int]): Local window size. Default: (1,7,7).
441
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
442
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
443
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
444
+ drop (float, optional): Dropout rate. Default: 0.0
445
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
446
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
447
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
448
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
449
+ """
450
+
451
+ def __init__(self,
452
+ dim,
453
+ depth,
454
+ num_heads,
455
+ window_size=(1,7,7),
456
+ mlp_ratio=4.,
457
+ qkv_bias=False,
458
+ qk_scale=None,
459
+ drop=0.,
460
+ attn_drop=0.,
461
+ drop_path=0.,
462
+ norm_layer=nn.LayerNorm,
463
+ downsample=None,
464
+ in_chans=1,
465
+ patch_norm=True,
466
+ patch_size=(3,4,4),
467
+ use_checkpoint=False):
468
+ super().__init__()
469
+ self.window_size = window_size
470
+ self.shift_size = tuple(i // 2 for i in window_size)
471
+ self.depth = depth
472
+ self.use_checkpoint = use_checkpoint
473
+
474
+ self.basic_layer = BasicLayer(
475
+ dim=dim,
476
+ depth=depth,
477
+ num_heads=num_heads,
478
+ window_size=window_size,
479
+ mlp_ratio=mlp_ratio,
480
+ qkv_bias=qkv_bias,
481
+ qk_scale=qk_scale,
482
+ drop=drop,
483
+ attn_drop=attn_drop,
484
+ drop_path=drop_path,
485
+ norm_layer=norm_layer,
486
+ # downsample=PatchMerging if i_layer<self.num_layers-1 else None,
487
+ downsample=None,
488
+ use_checkpoint=use_checkpoint)
489
+
490
+ self.resi_connection1 = nn.Conv2d(3, 64, 3, 1, 1)
491
+ self.resi_connection2 = ResidualGroup(num_feat=64,squeeze_factor=16,num_block=12)
492
+ self.resi_connection3 = nn.Conv2d(64, 3, 3, 1, 1)
493
+
494
+ # split image into non-overlapping patches
495
+ self.patch_embed = PatchEmbed3D(
496
+ patch_size=patch_size, in_chans=in_chans, embed_dim=dim,
497
+ norm_layer=norm_layer if patch_norm else None)
498
+
499
+ # split image into non-overlapping patches
500
+ self.patch_unembed = PatchUnEmbed3D(
501
+ patch_size=patch_size, in_chans=in_chans, embed_dim=dim,
502
+ norm_layer=norm_layer if patch_norm else None)
503
+
504
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
505
+
506
+ def forward(self, x):
507
+ shortcut = x
508
+ x = self.basic_layer(x)
509
+ x = self.patch_unembed(x)
510
+ x = self.resi_connection1(x)
511
+ x = self.lrelu(x)
512
+ x = self.resi_connection2(x)
513
+ x = self.lrelu(x)
514
+ x = self.resi_connection3(x)
515
+ x = self.lrelu(x)
516
+ x = self.patch_embed(x)
517
+ x = x + shortcut
518
+ x = self.lrelu(x)
519
+ return x
520
+
521
+
522
+ class BasicLayer(nn.Module):
523
+ """ A basic Swin Transformer layer for one stage.
524
+
525
+ Args:
526
+ dim (int): Number of feature channels
527
+ depth (int): Depths of this stage.
528
+ num_heads (int): Number of attention head.
529
+ window_size (tuple[int]): Local window size. Default: (1,7,7).
530
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
531
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
532
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
533
+ drop (float, optional): Dropout rate. Default: 0.0
534
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
535
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
536
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
537
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
538
+ """
539
+
540
+ def __init__(self,
541
+ dim,
542
+ depth,
543
+ num_heads,
544
+ window_size=(1,7,7),
545
+ mlp_ratio=4.,
546
+ qkv_bias=False,
547
+ qk_scale=None,
548
+ drop=0.,
549
+ attn_drop=0.,
550
+ drop_path=0.,
551
+ norm_layer=nn.LayerNorm,
552
+ downsample=None,
553
+ use_checkpoint=False):
554
+ super().__init__()
555
+ self.window_size = window_size
556
+ self.shift_size = tuple(i // 2 for i in window_size)
557
+ self.depth = depth
558
+ self.use_checkpoint = use_checkpoint
559
+
560
+ # build blocks
561
+ self.blocks = nn.ModuleList([
562
+ SwinTransformerBlock3D(
563
+ dim=dim,
564
+ num_heads=num_heads,
565
+ window_size=window_size,
566
+ shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,
567
+ mlp_ratio=mlp_ratio,
568
+ qkv_bias=qkv_bias,
569
+ qk_scale=qk_scale,
570
+ drop=drop,
571
+ attn_drop=attn_drop,
572
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
573
+ norm_layer=norm_layer,
574
+ use_checkpoint=use_checkpoint,
575
+ )
576
+ for i in range(depth)])
577
+
578
+ self.downsample = downsample
579
+ if self.downsample is not None:
580
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
581
+
582
+ def forward(self, x):
583
+ """ Forward function.
584
+
585
+ Args:
586
+ x: Input feature, tensor size (B, C, D, H, W).
587
+ """
588
+ # calculate attention mask for SW-MSA
589
+ B, C, D, H, W = x.shape
590
+ window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
591
+ x = rearrange(x, 'b c d h w -> b d h w c')
592
+ Dp = int(np.ceil(D / window_size[0])) * window_size[0]
593
+ Hp = int(np.ceil(H / window_size[1])) * window_size[1]
594
+ Wp = int(np.ceil(W / window_size[2])) * window_size[2]
595
+ attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
596
+ for blk in self.blocks:
597
+ x = blk(x, attn_mask)
598
+ x = x.view(B, D, H, W, -1)
599
+
600
+ if self.downsample is not None:
601
+ x = self.downsample(x)
602
+ x = rearrange(x, 'b d h w c -> b c d h w')
603
+ return x
604
+
605
+
606
+ class PatchEmbed3D(nn.Module):
607
+ """ Video to Patch Embedding.
608
+
609
+ Args:
610
+ patch_size (tuple[int]): Patch token size. Default: (3,4,4).
611
+ in_chans (int): Number of input video channels. Default: 3.
612
+ embed_dim (int): Number of linear projection output channels. Default: 96.
613
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
614
+ """
615
+ def __init__(self, patch_size=(3,4,4), in_chans=3, embed_dim=96, norm_layer=None):
616
+ super().__init__()
617
+ self.patch_size = patch_size
618
+
619
+ #print('received patch size', patch_size)
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
624
+ if norm_layer is not None:
625
+ self.norm = norm_layer(embed_dim)
626
+ else:
627
+ self.norm = None
628
+
629
+ def forward(self, x):
630
+ """Forward function."""
631
+ x = x.unsqueeze(1) # assuming gray scale video frames are encoded as channels, now separate
632
+
633
+ x = self.proj(x) # B C D Wh Ww
634
+ if self.norm is not None:
635
+ #print('inside here with self.norm')
636
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
637
+ x = x.flatten(2).transpose(1, 2)
638
+ x = self.norm(x)
639
+ x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
640
+
641
+ return x
642
+
643
+ class PatchUnEmbed3D(nn.Module):
644
+ def __init__(self, patch_size=(3,4,4), in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm):
645
+ super().__init__()
646
+ self.patch_size = patch_size
647
+
648
+ self.in_chans = in_chans
649
+ self.embed_dim = embed_dim
650
+
651
+ unembed_dim = 1
652
+ self.unembed_dim = unembed_dim
653
+
654
+ self.proj = nn.ConvTranspose3d(embed_dim, unembed_dim, kernel_size=patch_size, stride=patch_size)
655
+ self.conv = nn.Conv2d(3*unembed_dim, 3, 3, 1, 1)
656
+
657
+ if norm_layer is not None:
658
+ self.norm = norm_layer(unembed_dim)
659
+ else:
660
+ self.norm = None
661
+
662
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
663
+
664
+ def forward(self, x):
665
+
666
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
667
+ # x = x.view(-1,self.embed_dim*D,Wh,Ww)
668
+ x = self.proj(x)
669
+
670
+ # if self.norm is not None:
671
+ # D, Wh, Ww = x.size(2), x.size(3), x.size(4)
672
+ # x = x.flatten(2).transpose(1, 2)
673
+ # x = self.norm(x)
674
+ # x = x.transpose(1, 2).view(-1, self.unembed_dim, D, Wh, Ww)
675
+
676
+
677
+
678
+ x = self.lrelu(x)
679
+ x = x.view(-1,3*D,4*Wh,4*Ww)
680
+ # x = x.flatten(start_dim=1,end_dim=2)
681
+ # x = x.view(-1,9,4*Wh,4*Ww) # 18 128 128
682
+ x = self.lrelu(self.conv(x)) # 64 128 128
683
+
684
+ return x
685
+
686
+ class Upsampler(nn.Module):
687
+ def __init__(self, patch_size=(3,4,4), in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm):
688
+ super().__init__()
689
+ self.patch_size = patch_size
690
+
691
+ self.in_chans = in_chans
692
+ self.embed_dim = embed_dim
693
+
694
+ self.expand = nn.Conv2d(9, 20, 3, 1, 1)
695
+
696
+ self.shuffle = nn.PixelShuffle(2)
697
+ self.fusion = nn.Conv2d(20//4, 1, 3, 1, 1)
698
+
699
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
700
+
701
+ def forward(self, x):
702
+
703
+ # x = x.view(-1,self.embed_dim*D,Wh,Ww)
704
+ x = self.lrelu(self.expand(x))
705
+ x = self.shuffle(x) # 16 256 256
706
+ x = self.lrelu(self.fusion(x))
707
+
708
+ return x
709
+
710
+
711
+ class SwinTransformer3D_RCAB(nn.Module):
712
+ """ Swin Transformer backbone.
713
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
714
+ https://arxiv.org/pdf/2103.14030
715
+
716
+ Args:
717
+ patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
718
+ in_chans (int): Number of input image channels. Default: 1.
719
+ embed_dim (int): Number of linear projection output channels. Default: 96.
720
+ depths (tuple[int]): Depths of each Swin Transformer stage.
721
+ num_heads (tuple[int]): Number of attention heads of each stage.
722
+ window_size (tuple[int]): Window size. Default: (2,7,7).
723
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
724
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
725
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
726
+ drop_rate (float): Dropout rate.
727
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
728
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
729
+ norm_layer: Normalization layer. Default: nn.LayerNorm.
730
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
731
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
732
+ -1 means not freezing any parameters.
733
+ """
734
+
735
+ def __init__(self,
736
+ opt,
737
+ patch_size=(4,4,4),
738
+ in_chans=1,
739
+ embed_dim=96,
740
+ depths=[2, 2, 6, 2],
741
+ num_heads=[3, 6, 12, 24],
742
+ window_size=(2,7,7),
743
+ mlp_ratio=4.,
744
+ qkv_bias=True,
745
+ qk_scale=None,
746
+ drop_rate=0.,
747
+ attn_drop_rate=0.,
748
+ drop_path_rate=0.2,
749
+ norm_layer=nn.LayerNorm,
750
+ patch_norm=True,
751
+ upscale=2,
752
+ frozen_stages=-1,
753
+ use_checkpoint=False,
754
+ vis=False,
755
+ **kwargs):
756
+ super().__init__()
757
+
758
+ self.num_layers = len(depths)
759
+ self.embed_dim = embed_dim
760
+ self.patch_norm = patch_norm
761
+ self.window_size = window_size
762
+ self.patch_size = patch_size
763
+
764
+ # split image into non-overlapping patches
765
+ self.patch_embed = PatchEmbed3D(
766
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
767
+ norm_layer=norm_layer if self.patch_norm else None)
768
+
769
+ # split image into non-overlapping patches
770
+ self.patch_unembed = PatchUnEmbed3D(
771
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
772
+ norm_layer=norm_layer if self.patch_norm else None)
773
+
774
+ self.pos_drop = nn.Dropout(p=drop_rate)
775
+
776
+ # stochastic depth
777
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
778
+
779
+ # build layers
780
+ self.layers = nn.ModuleList()
781
+ for i_layer in range(self.num_layers):
782
+ layer = RSTB3D(
783
+ dim=embed_dim,
784
+ depth=depths[i_layer],
785
+ num_heads=num_heads[i_layer],
786
+ window_size=window_size,
787
+ mlp_ratio=mlp_ratio,
788
+ qkv_bias=qkv_bias,
789
+ qk_scale=qk_scale,
790
+ drop=drop_rate,
791
+ attn_drop=attn_drop_rate,
792
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
793
+ norm_layer=norm_layer,
794
+ # downsample=PatchMerging if i_layer<self.num_layers-1 else None,
795
+ downsample=None,
796
+ in_chans=in_chans,
797
+ patch_size=patch_size,
798
+ patch_norm=patch_norm,
799
+ use_checkpoint=use_checkpoint)
800
+ self.layers.append(layer)
801
+
802
+ self.num_features = int(embed_dim)
803
+
804
+ # add a norm layer for each output
805
+ self.norm = norm_layer(self.num_features)
806
+
807
+ self.upsampler = Upsampler(embed_dim=embed_dim)
808
+ self.task = opt.task
809
+
810
+ self.segmentation_decode = nn.Conv2d(3, 4, 1)
811
+
812
+
813
+ def init_weights(self, pretrained=None):
814
+ """Initialize the weights in backbone.
815
+
816
+ Args:
817
+ pretrained (str, optional): Path to pre-trained weights.
818
+ Defaults to None.
819
+ """
820
+ def _init_weights(m):
821
+ if isinstance(m, nn.Linear):
822
+ trunc_normal_(m.weight, std=.02)
823
+ if isinstance(m, nn.Linear) and m.bias is not None:
824
+ nn.init.constant_(m.bias, 0)
825
+ elif isinstance(m, nn.LayerNorm):
826
+ nn.init.constant_(m.bias, 0)
827
+ nn.init.constant_(m.weight, 1.0)
828
+
829
+ if pretrained:
830
+ self.pretrained = pretrained
831
+ # if isinstance(self.pretrained, str):
832
+ # self.apply(_init_weights)
833
+ # logger = get_root_logger()
834
+ # logger.info(f'load model from: {self.pretrained}')
835
+
836
+ # if self.pretrained2d:
837
+ # Inflate 2D model into 3D model.
838
+ # self.inflate_weights(logger)
839
+ # else:
840
+ # Directly load 3D model.
841
+ # load_checkpoint(self, self.pretrained, strict=False, logger=logger)
842
+ elif pretrained is None:
843
+ self.apply(_init_weights)
844
+ else:
845
+ raise TypeError('pretrained must be a str or None')
846
+
847
+ def forward(self, x):
848
+ """Forward function."""
849
+
850
+ shortcut = x
851
+ x = self.patch_embed(x)
852
+
853
+ x = self.pos_drop(x)
854
+ #print('after pos drop',x.shape)
855
+
856
+ for layer in self.layers:
857
+ x = layer(x.contiguous())
858
+
859
+
860
+ x = rearrange(x, 'n c d h w -> n d h w c')
861
+ #print('after rearrange',x.shape)
862
+ x = self.norm(x)
863
+ #print('after norm',x.shape)
864
+ x = rearrange(x, 'n d h w c -> n c d h w')
865
+ #print('after rearrange',x.shape)
866
+
867
+ x = self.patch_unembed(x)
868
+
869
+ x = x + shortcut
870
+
871
+ if self.task == 'segment':
872
+ x = self.segmentation_decode(x)
873
+
874
+ else:
875
+ x = self.upsampler(x)
876
+
877
+ return x
878
+
879
+
880
+
881
+
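A minimal usage sketch of the 3D windowing helpers defined above (not part of the diff). It assumes the repository root is on PYTHONPATH and that torch, timm and einops are installed, since swin3d_rcab imports them at module level; the dimensions and window size below are illustrative.

import torch
from archs.swin3d_rcab import window_partition, window_reverse

B, D, H, W, C = 2, 6, 56, 56, 96                      # every spatial dim divisible by the window size
window_size = (2, 7, 7)

x = torch.randn(B, D, H, W, C)
windows = window_partition(x, window_size)            # (B*num_windows, 2*7*7, C) = (384, 98, 96)
y = window_reverse(windows, window_size, B, D, H, W)  # back to (B, D, H, W, C)
assert torch.equal(x, y)                              # partition/reverse is a lossless round trip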
archs/swinir_rcab.py ADDED
@@ -0,0 +1,1296 @@
1
+ # Modified from https://github.com/JingyunLiang/SwinIR
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+
5
+ import collections.abc
6
+ import math
+ import warnings
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint as checkpoint
10
+ from itertools import repeat
11
+
12
+ # from self_attention_cv import AxialAttentionBlock
13
+
14
+ from functools import reduce, lru_cache
15
+ from operator import mul
16
+ from einops import rearrange
17
+ import sys
18
+
19
+
20
+ def make_layer(basic_block, num_basic_block, **kwarg):
21
+ """Make layers by stacking the same blocks.
22
+
23
+ Args:
24
+ basic_block (nn.module): nn.module class for basic block.
25
+ num_basic_block (int): number of blocks.
26
+
27
+ Returns:
28
+ nn.Sequential: Stacked blocks in nn.Sequential.
29
+ """
30
+ layers = []
31
+ for _ in range(num_basic_block):
32
+ layers.append(basic_block(**kwarg))
33
+ return nn.Sequential(*layers)
34
+
35
+ class Upsample(nn.Sequential):
36
+ """Upsample module.
37
+
38
+ Args:
39
+ scale (int): Scale factor. Supported scales: 2^n and 3.
40
+ num_feat (int): Channel number of intermediate features.
41
+ """
42
+
43
+ def __init__(self, scale, num_feat):
44
+ m = []
45
+ if (scale & (scale - 1)) == 0: # scale = 2^n
46
+ for _ in range(int(math.log(scale, 2))):
47
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
48
+ m.append(nn.PixelShuffle(2))
49
+ elif scale == 3:
50
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
51
+ m.append(nn.PixelShuffle(3))
52
+ else:
53
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
54
+ super(Upsample, self).__init__(*m)
55
+
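A minimal shape sketch for the Upsample block above (not part of the diff): a power-of-two scale is realised as repeated Conv2d + PixelShuffle(2) stages, so the spatial size grows by the scale factor while the channel count is preserved. Assumes torch is installed and the archs package is importable.

import torch
from archs.swinir_rcab import Upsample

up = Upsample(scale=4, num_feat=64)   # two Conv2d(64 -> 256) + PixelShuffle(2) stages
x = torch.randn(1, 64, 32, 32)
print(up(x).shape)                    # torch.Size([1, 64, 128, 128])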
56
+
57
+ # From PyTorch
58
+ def _ntuple(n):
59
+
60
+ def parse(x):
61
+ if isinstance(x, collections.abc.Iterable):
62
+ return x
63
+ return tuple(repeat(x, n))
64
+
65
+ return parse
66
+ to_2tuple = _ntuple(2)
67
+
68
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
69
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
70
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
71
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
72
+ def norm_cdf(x):
73
+ # Computes standard normal cumulative distribution function
74
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
75
+
76
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
77
+ warnings.warn(
78
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
79
+ 'The distribution of values may be incorrect.',
80
+ stacklevel=2)
81
+
82
+ with torch.no_grad():
83
+ # Values are generated by using a truncated uniform distribution and
84
+ # then using the inverse CDF for the normal distribution.
85
+ # Get upper and lower cdf values
86
+ low = norm_cdf((a - mean) / std)
87
+ up = norm_cdf((b - mean) / std)
88
+
89
+ # Uniformly fill tensor with values from [low, up], then translate to
90
+ # [2l-1, 2u-1].
91
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
92
+
93
+ # Use inverse cdf transform for normal distribution to get truncated
94
+ # standard normal
95
+ tensor.erfinv_()
96
+
97
+ # Transform to proper mean, std
98
+ tensor.mul_(std * math.sqrt(2.))
99
+ tensor.add_(mean)
100
+
101
+ # Clamp to ensure it's in the proper range
102
+ tensor.clamp_(min=a, max=b)
103
+ return tensor
104
+
105
+
106
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
107
+ r"""Fills the input Tensor with values drawn from a truncated
108
+ normal distribution.
109
+
110
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
111
+
112
+ The values are effectively drawn from the
113
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
114
+ with values outside :math:`[a, b]` redrawn until they are within
115
+ the bounds. The method used for generating the random values works
116
+ best when :math:`a \leq \text{mean} \leq b`.
117
+
118
+ Args:
119
+ tensor: an n-dimensional `torch.Tensor`
120
+ mean: the mean of the normal distribution
121
+ std: the standard deviation of the normal distribution
122
+ a: the minimum cutoff value
123
+ b: the maximum cutoff value
124
+
125
+ Examples:
126
+ >>> w = torch.empty(3, 5)
127
+ >>> nn.init.trunc_normal_(w)
128
+ """
129
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
130
+
131
+ class ChannelAttention(nn.Module):
132
+ """Channel attention used in RCAN.
133
+
134
+ Args:
135
+ num_feat (int): Channel number of intermediate features.
136
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
137
+ """
138
+
139
+ def __init__(self, num_feat, squeeze_factor=16):
140
+ super(ChannelAttention, self).__init__()
141
+ self.attention = nn.Sequential(
142
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
143
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
144
+
145
+ def forward(self, x):
146
+ y = self.attention(x)
147
+ return x * y
148
+
149
+
150
+ class RCAB(nn.Module):
151
+ """Residual Channel Attention Block (RCAB) used in RCAN.
152
+
153
+ Args:
154
+ num_feat (int): Channel number of intermediate features.
155
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
156
+ res_scale (float): Scale the residual. Default: 1.
157
+ """
158
+
159
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
160
+ super(RCAB, self).__init__()
161
+ self.res_scale = res_scale
162
+
163
+ self.rcab = nn.Sequential(
164
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
165
+ ChannelAttention(num_feat, squeeze_factor))
166
+
167
+ def forward(self, x):
168
+ res = self.rcab(x) * self.res_scale
169
+ return res + x
170
+
171
+
172
+ class ResidualGroup(nn.Module):
173
+ """Residual Group of RCAB.
174
+
175
+ Args:
176
+ num_feat (int): Channel number of intermediate features.
177
+ num_block (int): Block number in the body network.
178
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
179
+ res_scale (float): Scale the residual. Default: 1.
180
+ """
181
+
182
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
183
+ super(ResidualGroup, self).__init__()
184
+
185
+ self.residual_group = make_layer(
186
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
187
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
188
+
189
+ def forward(self, x):
190
+ res = self.conv(self.residual_group(x))
191
+ return res + x
192
+
193
+
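A minimal sketch of the RCAB building blocks above (not part of the diff): ResidualGroup stacks num_block RCABs plus a trailing 3x3 conv and is shape-preserving, so it can sit inside any num_feat-channel feature pipeline. Assumes torch is installed and the archs package is importable.

import torch
from archs.swinir_rcab import ResidualGroup

group = ResidualGroup(num_feat=64, num_block=4, squeeze_factor=16, res_scale=1)
feat = torch.randn(1, 64, 48, 48)     # (B, C, H, W) feature map
out = group(feat)
assert out.shape == feat.shape        # channel attention + residuals keep the shape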
194
+
195
+
196
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
197
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
198
+
199
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
200
+ """
201
+ if drop_prob == 0. or not training:
202
+ return x
203
+ keep_prob = 1 - drop_prob
204
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
205
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
206
+ random_tensor.floor_() # binarize
207
+ output = x.div(keep_prob) * random_tensor
208
+ return output
209
+
210
+
211
+ class DropPath(nn.Module):
212
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
213
+
214
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
215
+ """
216
+
217
+ def __init__(self, drop_prob=None):
218
+ super(DropPath, self).__init__()
219
+ self.drop_prob = drop_prob
220
+
221
+ def forward(self, x):
222
+ return drop_path(x, self.drop_prob, self.training)
223
+
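A minimal behaviour sketch for DropPath above (not part of the diff): during training each sample's residual branch is dropped with probability drop_prob and the survivors are rescaled by 1/keep_prob; at eval time it is the identity. Assumes torch is installed and the archs package is importable.

import torch
from archs.swinir_rcab import DropPath

dp = DropPath(drop_prob=0.5)
dp.train()
x = torch.ones(8, 4)
print(dp(x))                          # each row is either all zeros or all 2.0
dp.eval()
assert torch.equal(dp(x), x)          # identity at inference time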
224
+
225
+ class Mlp(nn.Module):
226
+
227
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
228
+ super().__init__()
229
+ out_features = out_features or in_features
230
+ hidden_features = hidden_features or in_features
231
+ self.fc1 = nn.Linear(in_features, hidden_features)
232
+ self.act = act_layer()
233
+ self.fc2 = nn.Linear(hidden_features, out_features)
234
+ self.drop = nn.Dropout(drop)
235
+
236
+ def forward(self, x):
237
+ x = self.fc1(x)
238
+ x = self.act(x)
239
+ x = self.drop(x)
240
+ x = self.fc2(x)
241
+ x = self.drop(x)
242
+ return x
243
+
244
+
245
+ def window_partition(x, window_size):
246
+ """
247
+ Args:
248
+ x: (b, h, w, c)
249
+ window_size (int): window size
250
+
251
+ Returns:
252
+ windows: (num_windows*b, window_size, window_size, c)
253
+ """
254
+ b, h, w, c = x.shape
255
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
256
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
257
+ return windows
258
+
259
+
260
+ def window_reverse(windows, window_size, h, w):
261
+ """
262
+ Args:
263
+ windows: (num_windows*b, window_size, window_size, c)
264
+ window_size (int): Window size
265
+ h (int): Height of image
266
+ w (int): Width of image
267
+
268
+ Returns:
269
+ x: (b, h, w, c)
270
+ """
271
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
272
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
273
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
274
+ return x
275
+
276
+
277
+ class WindowAttention(nn.Module):
278
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
279
+ It supports both of shifted and non-shifted window.
280
+
281
+ Args:
282
+ dim (int): Number of input channels.
283
+ window_size (tuple[int]): The height and width of the window.
284
+ num_heads (int): Number of attention heads.
285
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
286
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
287
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
288
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
289
+ """
290
+
291
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
292
+
293
+ super().__init__()
294
+ self.dim = dim
295
+ self.window_size = window_size # Wh, Ww
296
+ self.num_heads = num_heads
297
+ head_dim = dim // num_heads
298
+ self.scale = qk_scale or head_dim**-0.5
299
+
300
+ # define a parameter table of relative position bias
301
+ self.relative_position_bias_table = nn.Parameter(
302
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
303
+
304
+ # get pair-wise relative position index for each token inside the window
305
+ coords_h = torch.arange(self.window_size[0])
306
+ coords_w = torch.arange(self.window_size[1])
307
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
308
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
309
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
310
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
311
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
312
+ relative_coords[:, :, 1] += self.window_size[1] - 1
313
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
314
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
315
+ self.register_buffer('relative_position_index', relative_position_index)
316
+
317
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
318
+ self.attn_drop = nn.Dropout(attn_drop)
319
+ self.proj = nn.Linear(dim, dim)
320
+
321
+ self.proj_drop = nn.Dropout(proj_drop)
322
+
323
+ trunc_normal_(self.relative_position_bias_table, std=.02)
324
+ self.softmax = nn.Softmax(dim=-1)
325
+
326
+ def forward(self, x, mask=None):
327
+ """
328
+ Args:
329
+ x: input features with shape of (num_windows*b, n, c)
330
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
331
+ """
332
+ b_, n, c = x.shape
333
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
334
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
335
+
336
+ q = q * self.scale
337
+ attn = (q @ k.transpose(-2, -1))
338
+
339
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
340
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
341
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
342
+ attn = attn + relative_position_bias.unsqueeze(0)
343
+
344
+ if mask is not None:
345
+ nw = mask.shape[0]
346
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
347
+ attn = attn.view(-1, self.num_heads, n, n)
348
+ attn = self.softmax(attn)
349
+ else:
350
+ attn = self.softmax(attn)
351
+
352
+ attn = self.attn_drop(attn)
353
+
354
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
355
+ x = self.proj(x)
356
+ x = self.proj_drop(x)
357
+ return x
358
+
359
+ def extra_repr(self) -> str:
360
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
361
+
362
+ def flops(self, n):
363
+ # calculate flops for 1 window with token length of n
364
+ flops = 0
365
+ # qkv = self.qkv(x)
366
+ flops += n * self.dim * 3 * self.dim
367
+ # attn = (q @ k.transpose(-2, -1))
368
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
369
+ # x = (attn @ v)
370
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
371
+ # x = self.proj(x)
372
+ flops += n * self.dim * self.dim
373
+ return flops
374
+
375
+
376
+ class SwinTransformerBlock(nn.Module):
377
+ r""" Swin Transformer Block.
378
+
379
+ Args:
380
+ dim (int): Number of input channels.
381
+ input_resolution (tuple[int]): Input resolution.
382
+ num_heads (int): Number of attention heads.
383
+ window_size (int): Window size.
384
+ shift_size (int): Shift size for SW-MSA.
385
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
386
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
387
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
388
+ drop (float, optional): Dropout rate. Default: 0.0
389
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
390
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
391
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
392
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
393
+ """
394
+
395
+ def __init__(self,
396
+ dim,
397
+ input_resolution,
398
+ num_heads,
399
+ window_size=7,
400
+ shift_size=0,
401
+ mlp_ratio=4.,
402
+ qkv_bias=True,
403
+ qk_scale=None,
404
+ drop=0.,
405
+ attn_drop=0.,
406
+ drop_path=0.,
407
+ act_layer=nn.GELU,
408
+ norm_layer=nn.LayerNorm):
409
+ super().__init__()
410
+ self.dim = dim
411
+ self.input_resolution = input_resolution
412
+ self.num_heads = num_heads
413
+ self.window_size = window_size
414
+ self.shift_size = shift_size
415
+ self.mlp_ratio = mlp_ratio
416
+ if min(self.input_resolution) <= self.window_size:
417
+ # if window size is larger than input resolution, we don't partition windows
418
+ self.shift_size = 0
419
+ self.window_size = min(self.input_resolution)
420
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in 0-window_size'
421
+
422
+ self.norm1 = norm_layer(dim)
423
+ self.attn = WindowAttention(
424
+ dim,
425
+ window_size=to_2tuple(self.window_size),
426
+ num_heads=num_heads,
427
+ qkv_bias=qkv_bias,
428
+ qk_scale=qk_scale,
429
+ attn_drop=attn_drop,
430
+ proj_drop=drop)
431
+
432
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
433
+ self.norm2 = norm_layer(dim)
434
+ mlp_hidden_dim = int(dim * mlp_ratio)
435
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
436
+
437
+ if self.shift_size > 0:
438
+ attn_mask = self.calculate_mask(self.input_resolution)
439
+ else:
440
+ attn_mask = None
441
+
442
+ self.register_buffer('attn_mask', attn_mask)
443
+
444
+ def calculate_mask(self, x_size):
445
+ # calculate attention mask for SW-MSA
446
+ h, w = x_size
447
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
448
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
449
+ -self.shift_size), slice(-self.shift_size, None))
450
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
451
+ -self.shift_size), slice(-self.shift_size, None))
452
+ cnt = 0
453
+ for h in h_slices:
454
+ for w in w_slices:
455
+ img_mask[:, h, w, :] = cnt
456
+ cnt += 1
457
+
458
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
459
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
460
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
461
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
462
+
463
+ return attn_mask
464
+
465
+ def forward(self, x, x_size):
466
+ h, w = x_size
467
+ b, _, c = x.shape
468
+ # assert seq_len == h * w, "input feature has wrong size"
469
+
470
+ shortcut = x
471
+ x = self.norm1(x)
472
+ x = x.view(b, h, w, c)
473
+
474
+ # cyclic shift
475
+ if self.shift_size > 0:
476
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
477
+ else:
478
+ shifted_x = x
479
+
480
+ # partition windows
481
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
482
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
483
+
484
+ # W-MSA/SW-MSA (compatible with testing on images whose shapes are a multiple of the window size)
485
+ if self.input_resolution == x_size:
486
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
487
+ else:
488
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
489
+
490
+ # merge windows
491
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
492
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
493
+
494
+ # reverse cyclic shift
495
+ if self.shift_size > 0:
496
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
497
+ else:
498
+ x = shifted_x
499
+ x = x.view(b, h * w, c)
500
+
501
+ # FFN
502
+ x = shortcut + self.drop_path(x)
503
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
504
+
505
+ return x
506
+
507
+ def extra_repr(self) -> str:
508
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
509
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
510
+
511
+ def flops(self):
512
+ flops = 0
513
+ h, w = self.input_resolution
514
+ # norm1
515
+ flops += self.dim * h * w
516
+ # W-MSA/SW-MSA
517
+ nw = h * w / self.window_size / self.window_size
518
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
519
+ # mlp
520
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
521
+ # norm2
522
+ flops += self.dim * h * w
523
+ return flops
524
+
525
+
526
+ class PatchMerging(nn.Module):
527
+ r""" Patch Merging Layer.
528
+
529
+ Args:
530
+ input_resolution (tuple[int]): Resolution of input feature.
531
+ dim (int): Number of input channels.
532
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
533
+ """
534
+
535
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
536
+ super().__init__()
537
+ self.input_resolution = input_resolution
538
+ self.dim = dim
539
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
540
+ self.norm = norm_layer(4 * dim)
541
+
542
+ def forward(self, x):
543
+ """
544
+ x: b, h*w, c
545
+ """
546
+ h, w = self.input_resolution
547
+ b, seq_len, c = x.shape
548
+ assert seq_len == h * w, 'input feature has wrong size'
549
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'
550
+
551
+ x = x.view(b, h, w, c)
552
+
553
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
554
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
555
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
556
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
557
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
558
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
559
+
560
+ x = self.norm(x)
561
+ x = self.reduction(x)
562
+
563
+ return x
564
+
565
+ def extra_repr(self) -> str:
566
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
567
+
568
+ def flops(self):
569
+ h, w = self.input_resolution
570
+ flops = h * w * self.dim
571
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
572
+ return flops
573
+
574
+
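+ # A minimal, hypothetical usage sketch (sizes assumed): PatchMerging halves
+ # the spatial resolution and doubles the channel dimension of the token grid.
+ def _sketch_patch_merging_shapes():
+     merge = PatchMerging(input_resolution=(8, 8), dim=32)
+     tokens = torch.randn(1, 8 * 8, 32)    # (b, h*w, c)
+     out = merge(tokens)                   # (b, h/2 * w/2, 2*c)
+     assert out.shape == (1, 4 * 4, 64)
+     return out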
575
+ class BasicLayer(nn.Module):
576
+ """ A basic Swin Transformer layer for one stage.
577
+
578
+ Args:
579
+ dim (int): Number of input channels.
580
+ input_resolution (tuple[int]): Input resolution.
581
+ depth (int): Number of blocks.
582
+ num_heads (int): Number of attention heads.
583
+ window_size (int): Local window size.
584
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
585
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
586
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
587
+ drop (float, optional): Dropout rate. Default: 0.0
588
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
589
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
590
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
591
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
592
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
593
+ """
594
+
595
+ def __init__(self,
596
+ dim,
597
+ input_resolution,
598
+ depth,
599
+ num_heads,
600
+ window_size,
601
+ mlp_ratio=4.,
602
+ qkv_bias=True,
603
+ qk_scale=None,
604
+ drop=0.,
605
+ attn_drop=0.,
606
+ drop_path=0.,
607
+ norm_layer=nn.LayerNorm,
608
+ downsample=None,
609
+ use_checkpoint=False):
610
+
611
+ super().__init__()
612
+ self.dim = dim
613
+ self.input_resolution = input_resolution
614
+ self.depth = depth
615
+ self.use_checkpoint = use_checkpoint
616
+
617
+ # build blocks
618
+ self.blocks = nn.ModuleList([
619
+ SwinTransformerBlock(
620
+ dim=dim,
621
+ input_resolution=input_resolution,
622
+ num_heads=num_heads,
623
+ window_size=window_size,
624
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
625
+ mlp_ratio=mlp_ratio,
626
+ qkv_bias=qkv_bias,
627
+ qk_scale=qk_scale,
628
+ drop=drop,
629
+ attn_drop=attn_drop,
630
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
631
+ norm_layer=norm_layer) for i in range(depth)
632
+ ])
633
+
634
+ # patch merging layer
635
+ if downsample is not None:
636
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
637
+ else:
638
+ self.downsample = None
639
+
640
+ def forward(self, x, x_size):
641
+ for blk in self.blocks:
642
+ if self.use_checkpoint:
643
+ x = checkpoint.checkpoint(blk, x, x_size)
644
+ else:
645
+ x = blk(x, x_size)
646
+ if self.downsample is not None:
647
+ x = self.downsample(x)
648
+ return x
649
+
650
+ def extra_repr(self) -> str:
651
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
652
+
653
+ def flops(self):
654
+ flops = 0
655
+ for blk in self.blocks:
656
+ flops += blk.flops()
657
+ if self.downsample is not None:
658
+ flops += self.downsample.flops()
659
+ return flops
660
+
661
+
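+ # A small, hypothetical sketch (hyper-parameters assumed): the blocks inside a
+ # BasicLayer alternate between regular windows (shift 0) and shifted windows
+ # (window_size // 2), which is what lets information flow across window borders.
+ def _sketch_basic_layer_shift_pattern():
+     layer = BasicLayer(dim=32, input_resolution=(8, 8), depth=4,
+                        num_heads=4, window_size=4)
+     return [blk.shift_size for blk in layer.blocks]  # expected: [0, 2, 0, 2]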
662
+ class RSTB(nn.Module):
663
+ """Residual Swin Transformer Block (RSTB).
664
+
665
+ Args:
666
+ dim (int): Number of input channels.
667
+ input_resolution (tuple[int]): Input resolution.
668
+ depth (int): Number of blocks.
669
+ num_heads (int): Number of attention heads.
670
+ window_size (int): Local window size.
671
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
672
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
673
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
674
+ drop (float, optional): Dropout rate. Default: 0.0
675
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
676
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
677
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
678
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
679
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
680
+ img_size: Input image size.
681
+ patch_size: Patch size.
682
+ resi_connection: The convolutional block before residual connection.
683
+ """
684
+
685
+ def __init__(self,
686
+ dim,
687
+ input_resolution,
688
+ depth,
689
+ num_heads,
690
+ window_size,
691
+ mlp_ratio=4.,
692
+ qkv_bias=True,
693
+ qk_scale=None,
694
+ drop=0.,
695
+ attn_drop=0.,
696
+ drop_path=0.,
697
+ norm_layer=nn.LayerNorm,
698
+ downsample=None,
699
+ use_checkpoint=False,
700
+ img_size=224,
701
+ patch_size=4,
702
+ use_rcab=True,
703
+ resi_connection='1conv'):
704
+ super(RSTB, self).__init__()
705
+
706
+ self.dim = dim
707
+ self.input_resolution = input_resolution
708
+
709
+ self.residual_group = BasicLayer(
710
+ dim=dim,
711
+ input_resolution=input_resolution,
712
+ depth=depth,
713
+ num_heads=num_heads,
714
+ window_size=window_size,
715
+ mlp_ratio=mlp_ratio,
716
+ qkv_bias=qkv_bias,
717
+ qk_scale=qk_scale,
718
+ drop=drop,
719
+ attn_drop=attn_drop,
720
+ drop_path=drop_path,
721
+ norm_layer=norm_layer,
722
+ downsample=downsample,
723
+ use_checkpoint=use_checkpoint)
724
+
725
+ # if resi_connection == '1conv':
726
+ # # ML-SIM v1 v2 v3 v4 v6 v7 v8
727
+ # self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
728
+
729
+ # # ML-SIM v5
730
+ # # self.conv = nn.Sequential(
731
+ # # nn.PixelUnshuffle(2),
732
+ # # nn.Conv2d(4*dim, 4*dim, 3, 1, 1),
733
+ # # nn.PixelShuffle(2))
734
+
735
+ # elif resi_connection == '3conv':
736
+ # # to save parameters and memory
737
+ # self.conv = nn.Sequential(
738
+ # nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
739
+ # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
740
+ # nn.Conv2d(dim // 4, dim, 3, 1, 1))
741
+
742
+ self.use_rcab = use_rcab
743
+
744
+ self.resi_connection1 = nn.Conv2d(dim, dim, 3, 1, 1)
745
+ if self.use_rcab:
746
+ self.resi_connection2 = ResidualGroup(num_feat=dim,squeeze_factor=16,num_block=12)
747
+ self.resi_connection3 = nn.Conv2d(dim, dim, 3, 1, 1)
748
+
749
+ self.patch_embed = PatchEmbed(
750
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
751
+
752
+ self.patch_unembed = PatchUnEmbed(
753
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
754
+
755
+ def forward(self, x, x_size):
756
+ # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
757
+ shortcut = x
758
+ x = self.patch_unembed(self.residual_group(x, x_size), x_size)
759
+ x = self.resi_connection1(x)
760
+ if self.use_rcab:
761
+ x = self.resi_connection2(x)
762
+ x = self.resi_connection3(x)
763
+ x = self.patch_embed(x) + shortcut
764
+ return x
765
+
766
+ def flops(self):
767
+ flops = 0
768
+ flops += self.residual_group.flops()
769
+ h, w = self.input_resolution
770
+ flops += h * w * self.dim * self.dim * 9
771
+ flops += self.patch_embed.flops()
772
+ flops += self.patch_unembed.flops()
773
+
774
+ return flops
775
+
776
+
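+ # A minimal, hypothetical sketch (hyper-parameters assumed): an RSTB maps tokens
+ # (b, h*w, dim) to the same shape; the Swin blocks, the 3x3 conv(s) and the
+ # optional RCAB ResidualGroup all sit inside one long skip connection.
+ def _sketch_rstb_shapes():
+     rstb = RSTB(dim=32, input_resolution=(16, 16), depth=2, num_heads=4,
+                 window_size=4, img_size=16, patch_size=1)
+     tokens = torch.randn(1, 16 * 16, 32)
+     out = rstb(tokens, (16, 16))
+     assert out.shape == tokens.shape
+     return out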
777
+ class PatchEmbed(nn.Module):
778
+ r""" Image to Patch Embedding
779
+
780
+ Args:
781
+ img_size (int): Image size. Default: 224.
782
+ patch_size (int): Patch token size. Default: 4.
783
+ in_chans (int): Number of input image channels. Default: 3.
784
+ embed_dim (int): Number of linear projection output channels. Default: 96.
785
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
786
+ """
787
+
788
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
789
+ super().__init__()
790
+ img_size = to_2tuple(img_size)
791
+ patch_size = to_2tuple(patch_size)
792
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
793
+ self.img_size = img_size
794
+ self.patch_size = patch_size
795
+ self.patches_resolution = patches_resolution
796
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
797
+
798
+ self.in_chans = in_chans
799
+ self.embed_dim = embed_dim
800
+
801
+ if norm_layer is not None:
802
+ self.norm = norm_layer(embed_dim)
803
+ else:
804
+ self.norm = None
805
+
806
+ def forward(self, x):
807
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
808
+ if self.norm is not None:
809
+ x = self.norm(x)
810
+ return x
811
+
812
+ def flops(self):
813
+ flops = 0
814
+ h, w = self.img_size
815
+ if self.norm is not None:
816
+ flops += h * w * self.embed_dim
817
+ return flops
818
+
819
+
820
+ class PatchUnEmbed(nn.Module):
821
+ r""" Image to Patch Unembedding
822
+
823
+ Args:
824
+ img_size (int): Image size. Default: 224.
825
+ patch_size (int): Patch token size. Default: 4.
826
+ in_chans (int): Number of input image channels. Default: 3.
827
+ embed_dim (int): Number of linear projection output channels. Default: 96.
828
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
829
+ """
830
+
831
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
832
+ super().__init__()
833
+ img_size = to_2tuple(img_size)
834
+ patch_size = to_2tuple(patch_size)
835
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
836
+ self.img_size = img_size
837
+ self.patch_size = patch_size
838
+ self.patches_resolution = patches_resolution
839
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
840
+
841
+ self.in_chans = in_chans
842
+ self.embed_dim = embed_dim
843
+
844
+ def forward(self, x, x_size):
845
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
846
+ return x
847
+
848
+ def flops(self):
849
+ flops = 0
850
+ return flops
851
+
852
+
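+ # A minimal, hypothetical sketch: with patch_size=1 and no norm layer,
+ # PatchEmbed flattens (b, c, h, w) feature maps into tokens (b, h*w, c) and
+ # PatchUnEmbed reverses it exactly, so the pair is a lossless reshape.
+ def _sketch_patch_embed_roundtrip():
+     embed = PatchEmbed(img_size=16, patch_size=1, embed_dim=32, norm_layer=None)
+     unembed = PatchUnEmbed(img_size=16, patch_size=1, embed_dim=32, norm_layer=None)
+     feat = torch.randn(1, 32, 16, 16)
+     tokens = embed(feat)                  # (1, 256, 32)
+     restored = unembed(tokens, (16, 16))  # (1, 32, 16, 16)
+     assert torch.equal(feat, restored)
+     return restored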
853
+ class Upsample(nn.Sequential):
854
+ """Upsample module.
855
+
856
+ Args:
857
+ scale (int): Scale factor. Supported scales: 2^n and 3.
858
+ num_feat (int): Channel number of intermediate features.
859
+ """
860
+
861
+ def __init__(self, scale, num_feat):
862
+ m = []
863
+ if (scale & (scale - 1)) == 0: # scale = 2^n
864
+ for _ in range(int(math.log(scale, 2))):
865
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
866
+ m.append(nn.PixelShuffle(2))
867
+ elif scale == 3:
868
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
869
+ m.append(nn.PixelShuffle(3))
870
+ else:
871
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
872
+ super(Upsample, self).__init__(*m)
873
+
874
+
875
+ class UpsampleOneStep(nn.Sequential):
876
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
877
+ Used in lightweight SR to save parameters.
878
+
879
+ Args:
880
+ scale (int): Scale factor. Supported scales: 2^n and 3.
881
+ num_feat (int): Channel number of intermediate features.
882
+
883
+ """
884
+
885
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
886
+ self.num_feat = num_feat
887
+ self.input_resolution = input_resolution
888
+ m = []
889
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
890
+ m.append(nn.PixelShuffle(scale))
891
+ super(UpsampleOneStep, self).__init__(*m)
892
+
893
+ def flops(self):
894
+ h, w = self.input_resolution
895
+ flops = h * w * self.num_feat * 3 * 9
896
+ return flops
897
+
898
+
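+ # A minimal, hypothetical usage sketch (sizes assumed): both upsamplers rely on
+ # nn.PixelShuffle, which rearranges (b, r*r*c, h, w) into (b, c, r*h, r*w);
+ # Upsample stacks conv + PixelShuffle(2) stages for 2^n scales, while
+ # UpsampleOneStep uses a single conv + PixelShuffle(scale).
+ def _sketch_upsample_one_step():
+     up = UpsampleOneStep(scale=2, num_feat=16, num_out_ch=1)
+     feat = torch.randn(1, 16, 8, 8)
+     out = up(feat)   # conv -> (1, 4, 8, 8), pixel shuffle -> (1, 1, 16, 16)
+     assert out.shape == (1, 1, 16, 16)
+     return out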
899
+ class SwinIR_RCAB(nn.Module):
900
+ r""" SwinIR
901
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
902
+
903
+ Args:
904
+ img_size (int | tuple(int)): Input image size. Default 64
905
+ patch_size (int | tuple(int)): Patch size. Default: 1
906
+ in_chans (int): Number of input image channels. Default: 3
907
+ embed_dim (int): Patch embedding dimension. Default: 96
908
+ depths (tuple(int)): Depth of each Swin Transformer layer.
909
+ num_heads (tuple(int)): Number of attention heads in different layers.
910
+ window_size (int): Window size. Default: 7
911
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
912
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
913
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
914
+ drop_rate (float): Dropout rate. Default: 0
915
+ attn_drop_rate (float): Attention dropout rate. Default: 0
916
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
917
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
918
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
919
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
920
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
921
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
922
+ img_range: Image range. 1. or 255.
923
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
924
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
925
+ """
926
+
927
+ def __init__(self,
928
+ opt,
929
+ img_size=256,
930
+ patch_size=1,
931
+ in_chans=3,
932
+ embed_dim=64,
933
+ depths=(6, 6),
934
+ num_heads=(8,8),
935
+ window_size=4,
936
+ mlp_ratio=2.,
937
+ qkv_bias=True,
938
+ qk_scale=None,
939
+ drop_rate=0.,
940
+ attn_drop_rate=0.,
941
+ drop_path_rate=0.1,
942
+ norm_layer=nn.LayerNorm,
943
+ ape=False,
944
+ patch_norm=True,
945
+ use_checkpoint=False,
946
+ upscale=2,
947
+ img_range=1.,
948
+ upsampler='',
949
+ resi_connection='1conv',
950
+ pixelshuffleFactor=1,
951
+ use_rcab=True,
952
+ out_chans=1,
953
+ vis=False,
954
+ **kwargs):
955
+ super().__init__()
956
+ num_in_ch = in_chans
957
+ num_out_ch = out_chans  # in_chans in the original SwinIR
958
+ num_feat = 64
959
+ self.img_range = img_range
960
+ if in_chans == 3:
961
+ rgb_mean = (0.4488, 0.4371, 0.4040)
962
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
963
+ else:
964
+ self.mean = torch.zeros(1, 1, 1, 1)
965
+ self.upscale = upscale
966
+ self.upsampler = upsampler
967
+ print('received depths and use_rcab:', depths, use_rcab)
968
+
969
+ # ------------------------- 1, shallow feature extraction ------------------------- #
970
+ # ML-SIM v1 v2 v3 v6
971
+ # self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
972
+
973
+
974
+ # ML-SIM v4 v5 v8
975
+ # print('received pixelshufflefactor',pixelshuffleFactor)
976
+ self.conv_first = nn.Conv2d(round(pixelshuffleFactor**2*num_in_ch), embed_dim, 3, 1, 1)
977
+ if pixelshuffleFactor >= 1:
978
+ self.pixelshuffle_encode = nn.PixelUnshuffle(pixelshuffleFactor)
979
+ self.pixelshuffle_decode = nn.PixelShuffle(pixelshuffleFactor)
980
+ else: # e.g. 1/3
981
+ self.pixelshuffle_encode = nn.PixelShuffle(round(1/pixelshuffleFactor))
982
+ self.pixelshuffle_decode = nn.PixelUnshuffle(round(1/pixelshuffleFactor))
983
+
984
+ # ML-SIM v7
985
+ # pixelshuffleFactor = kwargs['pixelshuffleFactor']
986
+ # self.conv_first = nn.Conv2d(round(3*pixelshuffleFactor**2*num_in_ch), embed_dim, 3, 1, 1)
987
+ # if pixelshuffleFactor > 1:
988
+ # self.pixelshuffle_encode = nn.PixelUnshuffle(pixelshuffleFactor)
989
+ # self.pixelshuffle_decode = nn.PixelShuffle(pixelshuffleFactor)
990
+ # else: # e.g. 1/3
991
+ # self.pixelshuffle_encode = nn.PixelShuffle(round(1/pixelshuffleFactor))
992
+ # self.pixelshuffle_decode = nn.PixelUnshuffle(round(1/pixelshuffleFactor))
993
+
994
+
995
+
996
+
997
+ # ------------------------- 2, deep feature extraction ------------------------- #
998
+ self.num_layers = len(depths)
999
+ self.embed_dim = embed_dim
1000
+ self.ape = ape
1001
+ self.patch_norm = patch_norm
1002
+ self.num_features = embed_dim
1003
+ self.mlp_ratio = mlp_ratio
1004
+
1005
+ # split image into non-overlapping patches
1006
+ self.patch_embed = PatchEmbed(
1007
+ img_size=img_size,
1008
+ patch_size=patch_size,
1009
+ in_chans=embed_dim,
1010
+ embed_dim=embed_dim,
1011
+ norm_layer=norm_layer if self.patch_norm else None)
1012
+ num_patches = self.patch_embed.num_patches
1013
+ patches_resolution = self.patch_embed.patches_resolution
1014
+ self.patches_resolution = patches_resolution
1015
+
1016
+ # merge non-overlapping patches into image
1017
+ self.patch_unembed = PatchUnEmbed(
1018
+ img_size=img_size,
1019
+ patch_size=patch_size,
1020
+ in_chans=embed_dim,
1021
+ embed_dim=embed_dim,
1022
+ norm_layer=norm_layer if self.patch_norm else None)
1023
+
1024
+ # absolute position embedding
1025
+ if self.ape:
1026
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
1027
+ trunc_normal_(self.absolute_pos_embed, std=.02)
1028
+
1029
+ self.pos_drop = nn.Dropout(p=drop_rate)
1030
+
1031
+ # stochastic depth
1032
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1033
+
1034
+ # build Residual Swin Transformer blocks (RSTB)
1035
+ self.layers = nn.ModuleList()
1036
+ for i_layer in range(self.num_layers):
1037
+ layer = RSTB(
1038
+ dim=embed_dim,
1039
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1040
+ depth=depths[i_layer],
1041
+ num_heads=num_heads[i_layer],
1042
+ window_size=window_size,
1043
+ mlp_ratio=self.mlp_ratio,
1044
+ qkv_bias=qkv_bias,
1045
+ qk_scale=qk_scale,
1046
+ drop=drop_rate,
1047
+ attn_drop=attn_drop_rate,
1048
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
1049
+ norm_layer=norm_layer,
1050
+ downsample=None,
1051
+ use_checkpoint=use_checkpoint,
1052
+ img_size=img_size,
1053
+ patch_size=patch_size,
1054
+ use_rcab=use_rcab,
1055
+ resi_connection=resi_connection)
1056
+ self.layers.append(layer)
1057
+ self.norm = norm_layer(self.num_features)
1058
+
1059
+ # build the last conv layer in deep feature extraction
1060
+ if resi_connection == '1conv':
1061
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1062
+ elif resi_connection == '3conv':
1063
+ # to save parameters and memory
1064
+ self.conv_after_body = nn.Sequential(
1065
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1066
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1067
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
1068
+
1069
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1070
+ if self.upsampler == 'pixelshuffle':
1071
+ # for classical SR
1072
+ self.conv_before_upsample = nn.Sequential(
1073
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1074
+ self.upsample = Upsample(upscale, num_feat)
1075
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1076
+ elif self.upsampler == 'pixelshuffledirect':
1077
+ # for lightweight SR (to save parameters)
1078
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
1079
+ (patches_resolution[0], patches_resolution[1]))
1080
+ elif self.upsampler == 'nearest+conv':
1081
+ # for real-world SR (less artifacts)
1082
+ assert self.upscale == 4, 'only x4 is supported for nearest+conv.'
1083
+ self.conv_before_upsample = nn.Sequential(
1084
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1085
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1086
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1087
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1088
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1089
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1090
+ else:
1091
+ # for image denoising and JPEG compression artifact reduction
1092
+ # self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) # original code
1093
+
1094
+ # ML-SIM v1 v6
1095
+ # self.conv_last = nn.Conv2d(embed_dim, num_in_ch, 3, 1, 1)
1096
+ # self.conv_combine = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)
1097
+
1098
+ # ML-SIM v2,v3
1099
+ # self.conv_last = nn.Conv2d(embed_dim, num_in_ch, 3, 1, 1)
1100
+ # self.conv_combine = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)
1101
+ # self.axial_att_block = AxialAttentionBlock(in_channels=9, dim=256, heads=8)
1102
+
1103
+ # ML-SIM v4 v5
1104
+ # self.conv_last = nn.Conv2d(embed_dim, round(pixelshuffleFactor**2*num_in_ch), 3, 1, 1)
1105
+ # self.conv_combine = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)
1106
+
1107
+ # ML-SIM v7
1108
+ # self.conv_last = nn.Conv2d(embed_dim, round(3*pixelshuffleFactor**2*num_in_ch), 3, 1, 1)
1109
+ # self.conv_combine = nn.Conv2d(3*num_in_ch, num_out_ch, 3, 1, 1)
1110
+
1111
+ # ML-SIM v8
1112
+ self.conv_before_upsample = nn.Sequential(
1113
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1114
+ self.upsample = Upsample(upscale, num_feat)
1115
+ self.conv_last = nn.Conv2d(num_feat, round(pixelshuffleFactor**2*num_in_ch), 3, 1, 1)
1116
+ self.conv_combine = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)
1117
+
1118
+ self.task = opt.task
1119
+ if self.task == 'segment':
1120
+ self.segmentation_decode = nn.Conv2d(num_in_ch, 4, 1)
1121
+ self.vis = vis
1122
+ self.apply(self._init_weights)
1123
+
1124
+ def _init_weights(self, m):
1125
+ if isinstance(m, nn.Linear):
1126
+ trunc_normal_(m.weight, std=.02)
1127
+ if isinstance(m, nn.Linear) and m.bias is not None:
1128
+ nn.init.constant_(m.bias, 0)
1129
+ elif isinstance(m, nn.LayerNorm):
1130
+ nn.init.constant_(m.bias, 0)
1131
+ nn.init.constant_(m.weight, 1.0)
1132
+
1133
+ @torch.jit.ignore
1134
+ def no_weight_decay(self):
1135
+ return {'absolute_pos_embed'}
1136
+
1137
+ @torch.jit.ignore
1138
+ def no_weight_decay_keywords(self):
1139
+ return {'relative_position_bias_table'}
1140
+
1141
+ def forward_features(self, x):
1142
+ x_size = (x.shape[2], x.shape[3])
1143
+ # print('before patch embed',x.shape)
1144
+ x = self.patch_embed(x)
1145
+ # print('after patch embed',x.shape)
1146
+ if self.ape:
1147
+ x = x + self.absolute_pos_embed
1148
+ x = self.pos_drop(x)
1149
+
1150
+ for idx,layer in enumerate(self.layers):
1151
+ x = layer(x, x_size)
1152
+ if self.vis:
1153
+ x_unembed = self.patch_unembed(x, x_size)
1154
+ torch.save(x_unembed.detach().cpu(),'x_layer_%d.pth' % idx)
1155
+
1156
+
1157
+ x = self.norm(x) # b seq_len c
1158
+ # print('before patch unembed', x.shape)
1159
+ x = self.patch_unembed(x, x_size)
1160
+ # print('after patch unembed', x.shape)
1161
+
1162
+ return x
1163
+
1164
+ def forward(self, x):
1165
+ # print('starting forward',x.shape)
1166
+ self.mean = self.mean.type_as(x)
1167
+ x = (x - self.mean) * self.img_range
1168
+
1169
+ if self.upsampler == 'pixelshuffle':
1170
+ # for classical SR
1171
+ x = self.conv_first(x)
1172
+ x = self.conv_after_body(self.forward_features(x)) + x
1173
+ x = self.conv_before_upsample(x)
1174
+ x = self.conv_last(self.upsample(x))
1175
+ elif self.upsampler == 'pixelshuffledirect':
1176
+ # for lightweight SR
1177
+ x = self.conv_first(x)
1178
+ x = self.conv_after_body(self.forward_features(x)) + x
1179
+ x = self.upsample(x)
1180
+ elif self.upsampler == 'nearest+conv':
1181
+ # for real-world SR
1182
+ x = self.conv_first(x)
1183
+ x = self.conv_after_body(self.forward_features(x)) + x
1184
+ x = self.conv_before_upsample(x)
1185
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1186
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1187
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1188
+ else:
1189
+ # for image denoising and JPEG compression artifact reduction
1190
+
1191
+ # ML-SIM v1 v2 v3
1192
+ # x_first = self.conv_first(x)
1193
+ # res = self.conv_after_body(self.forward_features(x_first)) + x_first
1194
+ # res = self.conv_last(res)
1195
+
1196
+ # ML-SIM v1
1197
+ # x = self.conv_combine(x + res)
1198
+
1199
+ # ML-SIM v2
1200
+ # x = self.axial_att_block(x)
1201
+ # x = self.conv_combine(x + res)
1202
+
1203
+ # ML-SIM v3
1204
+ # res = self.axial_att_block(res)
1205
+ # x = self.conv_combine(x + res)
1206
+
1207
+ # ML-SIM v4 v5
1208
+ # x_encoded = self.pixelshuffle_encode(x)
1209
+ # x_first = self.conv_first(x_encoded)
1210
+ # res = self.conv_after_body(self.forward_features(x_first)) + x_first
1211
+ # res = self.conv_last(res)
1212
+ # res_decoded = self.pixelshuffle_decode(res)
1213
+ # x = self.conv_combine(x + res_decoded)
1214
+
1215
+
1216
+ # ML-SIM v6
1217
+ # x_encoded = torch.fft.fft2(x,dim=(-1,-2)).real
1218
+ # x_first = self.conv_first(x_encoded + x)
1219
+ # res = self.conv_after_body(self.forward_features(x_first)) + x_first
1220
+ # res = self.conv_last(res)
1221
+ # x = self.conv_combine(x + res)
1222
+
1223
+ # ML-SIM v7
1224
+ # x_cos = torch.cos(x)
1225
+ # x_sin = torch.sin(x)
1226
+ # x = torch.cat((x,x_cos,x_sin),dim=1)
1227
+ # x_encoded = self.pixelshuffle_encode(x)
1228
+ # x_first = self.conv_first(x_encoded)
1229
+ # res = self.conv_after_body(self.forward_features(x_first)) + x_first
1230
+ # res = self.conv_last(res)
1231
+ # res_decoded = self.pixelshuffle_decode(res)
1232
+ # x = self.conv_combine(x + res_decoded)
1233
+
1234
+ # ML-SIM v8
1235
+ x_encoded = self.pixelshuffle_encode(x)
1236
+
1237
+ # print('after pixelshuffle',x_encoded.shape)
1238
+ x_first = self.conv_first(x_encoded)
1239
+ # print('after conv first',x_first.shape)
1240
+ x_forwardfeat = self.forward_features(x_first)
1241
+ # print('after forward feat',x_forwardfeat.shape)
1242
+ res = self.conv_after_body(x_forwardfeat) + x_first
1243
+ # print('after conv after body',res.shape)
1244
+ x = self.conv_before_upsample(res)
1245
+ # print('after conv before upsample',x.shape)
1246
+ x = self.conv_last(self.upsample(x))
1247
+ # print('after conv last',x.shape)
1248
+
1249
+
1250
+ if self.task == 'segment':
1251
+ x = self.segmentation_decode(x) # assumes pixelshuffle = 1
1252
+ else:
1253
+ res_decoded = self.pixelshuffle_decode(x)
1254
+ # print('after pixel shuffle',res_decoded.shape)
1255
+ x = self.conv_combine(res_decoded)
1256
+ # print('after conv combine',x.shape)
1257
+
1258
+
1259
+ x = x / self.img_range + self.mean
1260
+
1261
+ return x
1262
+
1263
+ def flops(self):
1264
+ flops = 0
1265
+ h, w = self.patches_resolution
1266
+ flops += h * w * 3 * self.embed_dim * 9
1267
+ flops += self.patch_embed.flops()
1268
+ for layer in self.layers:
1269
+ flops += layer.flops()
1270
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
1271
+ flops += self.upsample.flops()
1272
+ return flops
1273
+
1274
+
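+ # A hypothetical shape trace (standalone modules, values assumed: in_chans=9,
+ # pixelshuffleFactor=2, upscale=2, embed_dim=num_feat=64) of the default ('')
+ # upsampler branch in SwinIR_RCAB.forward: encode with PixelUnshuffle, run the
+ # Swin/RCAB body at reduced resolution, upsample, then decode with PixelShuffle.
+ def _sketch_mlsim_v8_shapes():
+     b, c_in, h, w, psf, up, dim = 1, 9, 64, 64, 2, 2, 64
+     x = torch.randn(b, c_in, h, w)
+     x = nn.PixelUnshuffle(psf)(x)                      # encode: (1, 36, 32, 32)
+     x = nn.Conv2d(psf * psf * c_in, dim, 3, 1, 1)(x)   # conv_first: (1, 64, 32, 32)
+     # ... RSTB body, conv_after_body and conv_before_upsample keep this shape ...
+     x = nn.PixelShuffle(up)(nn.Conv2d(dim, up * up * dim, 3, 1, 1)(x))  # (1, 64, 64, 64)
+     x = nn.Conv2d(dim, psf * psf * c_in, 3, 1, 1)(x)   # conv_last: (1, 36, 64, 64)
+     x = nn.PixelShuffle(psf)(x)                        # decode: (1, 9, 128, 128)
+     return x.shape   # conv_combine would then map 9 channels to out_chans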
1275
+ if __name__ == '__main__':
1276
+ upscale = 4
1277
+ window_size = 8
1278
+ height = (1024 // upscale // window_size + 1) * window_size
1279
+ width = (720 // upscale // window_size + 1) * window_size
1280
+ # Note: this file defines SwinIR_RCAB, whose constructor also expects an
+ # `opt` object exposing a `task` attribute; a minimal stand-in is assumed here.
+ opt_stub = type('Opt', (), {'task': 'sr'})()
+ model = SwinIR_RCAB(
+ opt_stub,
1281
+ upscale=2,
1282
+ img_size=(height, width),
1283
+ window_size=window_size,
1284
+ img_range=1.,
1285
+ depths=[6, 6, 6, 6],
1286
+ embed_dim=60,
1287
+ num_heads=[6, 6, 6, 6],
1288
+ mlp_ratio=2,
1289
+ upsampler='pixelshuffledirect')
1290
+ print(model)
1291
+ print(height, width, model.flops() / 1e9)
1292
+
1293
+ x = torch.randn((1, 3, height, width))
1294
+ x = model(x)
1295
+ print(x.shape)
1296
+
requirements.txt CHANGED
@@ -6,3 +6,5 @@ scikit-image
6
 opencv-python
7
 numpy
8
 matplotlib
9
+ timm
10
+ einops