refine promptiqa.py
- PromptIQA/models/gc_loss.py +0 -99
- PromptIQA/models/monet_IPF.py +0 -397
- PromptIQA/models/monet_test.py +0 -389
- PromptIQA/models/monet_wo_prompt.py +0 -392
- PromptIQA/models/{monet.py → promptiqa.py} +2 -84
- PromptIQA/models/vit_base.py +0 -402
- PromptIQA/models/vit_large.py +0 -405
- PromptIQA/run_promptIQA copy.py +0 -109
- PromptIQA/run_promptIQA.py +2 -2
- PromptIQA/t.py +0 -2
- PromptIQA/test.py +0 -429
- PromptIQA/test.sh +0 -9
- best_model.pth.tar +3 -0
- get_examplt.py +0 -27
PromptIQA/models/gc_loss.py
DELETED
@@ -1,99 +0,0 @@
import torch.nn as nn
import torch
import numpy as np


class GC_Loss(nn.Module):
    def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
        super(GC_Loss, self).__init__()
        self.pred_queue = list()
        self.gt_queue = list()
        self.queue_len = 0

        self.queue_max_len = queue_len
        print('The queue length is: ', self.queue_max_len)
        self.mse = torch.nn.MSELoss().cuda()

        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def consistency(self, pred_data, gt_data):
        pred_one_batch, pred_queue = pred_data
        gt_one_batch, gt_queue = gt_data

        pred_mean = torch.mean(pred_queue)
        gt_mean = torch.mean(gt_queue)

        diff_pred = pred_one_batch - pred_mean
        diff_gt = gt_one_batch - gt_mean

        x1 = torch.sum(torch.mul(diff_pred, diff_gt))
        x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
        x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))

        return x1 / (x2_1 * x2_2)

    def ppra(self, x):
        """
        Pairwise Preference-based Rank Approximation
        """
        x_bar, x_std = torch.mean(x), torch.std(x)
        x_n = (x - x_bar) / x_std
        x_n_T = x_n.reshape(-1, 1)

        rank_x = x_n_T - x_n_T.transpose(1, 0)
        rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1)

        return rank_x

    @torch.no_grad()
    def enqueue(self, pred, gt):
        bs = pred.shape[0]
        self.queue_len = self.queue_len + bs

        self.pred_queue = self.pred_queue + pred.tolist()
        self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()

        if self.queue_len > self.queue_max_len:
            self.dequeue(self.queue_len - self.queue_max_len)
            self.queue_len = self.queue_max_len

    @torch.no_grad()
    def dequeue(self, n):
        for _ in range(n):
            self.pred_queue.pop(0)
            self.gt_queue.pop(0)

    def clear(self):
        self.pred_queue.clear()
        self.gt_queue.clear()

    def forward(self, x, y):
        x_queue = self.pred_queue.copy()
        y_queue = self.gt_queue.copy()

        x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
        y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)

        PLCC = self.consistency((x, x_all), (y, y_all))
        PGC = 1 - PLCC

        rank_x = self.ppra(x_all)
        rank_y = self.ppra(y_all)
        SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y))
        SGC = 1 - SROCC

        GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
        self.enqueue(x, y)

        return GC


if __name__ == '__main__':
    gc = GC_Loss().cuda()
    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
    y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()

    res = gc(x, y)

    print(res)
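The deleted loss combines a Pearson-style group consistency term (PGC), a rank-based group consistency term (SGC) computed over a running queue of past predictions, and an MSE term. Below is a minimal, hypothetical training-step sketch, not part of the commit: the tiny model, the random data and the hyper-parameter values are placeholders chosen only to show how the stateful queue is exercised across steps.

# Hypothetical usage sketch (not from this repository); the model, optimizer and data are
# stand-ins. GC_Loss keeps an internal queue of past predictions and labels, so the loss
# only becomes "group-aware" after a few steps have filled the queue.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1), nn.Sigmoid()).cuda()
criterion = GC_Loss(queue_len=800, alpha=0.5, beta=0.5, gamma=1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for step in range(3):
    images = torch.rand(4, 3, 224, 224).cuda()    # placeholder batch
    mos = torch.rand(4).cuda()                     # placeholder ground-truth scores in [0, 1]
    pred = model(images).view(-1)                  # predicted scores, shape (bs,)
    loss = criterion(pred, mos)                    # (alpha * PGC + beta * SGC + gamma) * MSE
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()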
PromptIQA/models/monet_IPF.py
DELETED
@@ -1,397 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm

class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention

        # Self attention module for each input feature
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
        bs, _, _ = input_tensor.shape  # [2, 3072, 784]

        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
                               c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768

        in_channel = input_tensor.permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28

        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2  # [2, 3072, 784]

        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

import torch
from functools import partial

class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        self.i_p_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 128),
        )

        self.prompt_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(8)])
        self.norm = nn.LayerNorm(128)

        self.score_block = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))[:, 0, :]  # bs, 2, 1
        # fusion = self.norm(fusion)[:, 0, :]
        # fusion = self.score_block(fusion)

        # # iq_res = torch.mean(fusion, dim=1).view(-1)
        # iq_res = fusion[:, 0].view(-1)

        return fusion

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        if data_type in self.prompt_feature:
            return
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-d feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)

        # Fuse img_feature and score_feature to obtain funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        # print('Load Prompt For Testing.', funsion_feature.shape)
        # self.prompt_feature = funsion_feature.clone()
        self.prompt_feature[data_type] = funsion_feature.clone()

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-d feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)  # bs, 128

        # Fuse img_feature and score_feature to obtain funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B

if __name__ == '__main__':
    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
    model = MoNet().cuda()

    iq_res, gt_res = model(in_feature, gt_feature)

    print(iq_res.shape)
    print(gt_res.shape)
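This variant caches one prompt per dataset: forward_prompt() turns a handful of image/score pairs into a 1 x n x 128 prompt stored under a dataset key, and inference() then fuses new images with that cached prompt. The snippet below is a hypothetical sketch of that call sequence, not code from the commit; the batch sizes and the 'koniq' key are invented placeholders, and a CUDA device plus the pretrained ViT weights are assumed.

# Hypothetical prompt workflow for the monet_IPF variant (assumed names and sizes).
import torch

model = MoNet().cuda().eval()

prompt_imgs = torch.rand(5, 3, 224, 224).cuda()     # few-shot example images
prompt_scores = torch.rand(5, 1).cuda()              # their normalized quality scores

if not model.check_prompt('koniq'):
    model.forward_prompt(prompt_imgs, prompt_scores, 'koniq')   # caches a 1 x 5 x 128 prompt

test_imgs = torch.rand(2, 3, 224, 224).cuda()
with torch.no_grad():
    feat = model.inference(test_imgs, 'koniq')        # bs x 128 fused feature in this variant
print(feat.shape)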
PromptIQA/models/monet_test.py
DELETED
@@ -1,389 +0,0 @@
# (The module docstring, imports and the Attention_Block, Self_Attention, MAL, SaveOutput,
#  concat_all_gather and Attention definitions are identical to those in monet_IPF.py above;
#  only the MoNet class differs.)

class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        # self.score_projection = nn.Sequential(
        #     nn.Linear(1, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 128),
        # )

        # self.i_p_fusion = nn.Sequential(
        #     Block(128, 8),
        #     Block(128, 8),
        #     Block(128, 8),
        # )
        self.i_p_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 128),
        )

        self.score_block = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            # Block(128, 4),
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        # self.diff_block = nn.Sequential(
        #     Block(128, 8),
        #     Block(128, 8),
        #     Block(128, 8),
        #     nn.Linear(128, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 1),
        # )
        self.prompt_feature = None

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = None

    @torch.no_grad()
    def inference(self, x):
        prompt_feature = self.prompt_feature  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128

        fusion = self.score_block(torch.cat((img_feature, prompt_feature), dim=1))  # bs, n, 1

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    @torch.no_grad()
    def forward_prompt(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-d feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)

        # Fuse img_feature and score_feature to obtain funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        print('Load Prompt For Testing.', funsion_feature.shape)
        self.prompt_feature = funsion_feature.clone()

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-d feature
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)  # bs, 128

        # Fuse img_feature and score_feature to obtain funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128

        fusion = self.score_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')


if __name__ == '__main__':
    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
    model = MoNet().cuda()

    iq_res, gt_res = model(in_feature, gt_feature)

    print(iq_res.shape)
    print(gt_res.shape)
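In this variant the prompts for a batch are built from the batch itself: expand() gives each image the fused image/score features of the other bs - 1 samples. The tiny standalone check below illustrates that leave-one-out behaviour; it is written for this note rather than taken from the repository, and it uses a 3 x 2 tensor instead of the real bs x 128 features.

# Illustration of the leave-one-out expansion performed by MoNet.expand() (assumed toy sizes).
import torch

A = torch.arange(6, dtype=torch.float).view(3, 2)             # pretend bs=3, feature dim=2
A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)          # 3 x 3 x 2
rows = [torch.cat((row[:i], row[i + 1:])).unsqueeze(0) for i, row in enumerate(A_expanded)]
B = torch.cat(rows, dim=0)
print(B.shape)   # torch.Size([3, 2, 2]); B[i] holds every row of A except row i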
PromptIQA/models/monet_wo_prompt.py
DELETED
@@ -1,392 +0,0 @@
# (The module docstring, imports and the Attention_Block, Self_Attention, MAL, SaveOutput,
#  concat_all_gather and Attention definitions are identical to those in monet_IPF.py above,
#  except that this file additionally has `import os` and a commented-out
#  `os.environ['CUDA_VISIBLE_DEVICES'] = '7'` near the top; only the MoNet class differs.
#  The listing of this file is cut off in the commit view partway through extract_feature.)

from functools import partial
class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=3)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        # self.i_p_fusion = nn.Sequential(
        #     Block(128, 4),
        #     Block(128, 4),
        #     Block(128, 4),
        # )
        # self.mlp = nn.Sequential(
        #     nn.Linear(128, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 128),
        # )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(8)])
        self.norm = nn.LayerNorm(128)

        self.score_block = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        # prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        # prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(img_feature)  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        pass
        # The rest of the original body is commented out in the deleted file; it mirrors
        # the prompt-construction code of forward_prompt in monet_IPF.py above.

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # Linearly map the score to a 128-d feature
        # score_feature = self.score_projection(score)  # bs, 128
        # score_feature = score.expand(-1, 128)  # bs, 128

        # # Fuse img_feature and score_feature to obtain funsion_feature
        # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        # funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        # funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(img_feature)  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x
|
| 368 |
-
x = torch.cat((x1, x2, x3, x4), dim=2)
|
| 369 |
-
return x
|
| 370 |
-
|
| 371 |
-
def expand(self, A):
|
| 372 |
-
A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
|
| 373 |
-
|
| 374 |
-
B = None
|
| 375 |
-
for index, i in enumerate(A_expanded):
|
| 376 |
-
rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
|
| 377 |
-
if B is None:
|
| 378 |
-
B = rmv
|
| 379 |
-
else:
|
| 380 |
-
B = torch.cat((B, rmv), dim=0)
|
| 381 |
-
|
| 382 |
-
return B
|
| 383 |
-
|
| 384 |
-
if __name__ == '__main__':
|
| 385 |
-
in_feature = torch.zeros((2, 3, 224, 224)).cuda()
|
| 386 |
-
gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2]], dtype=torch.float).cuda()
|
| 387 |
-
model = MoNet().cuda()
|
| 388 |
-
|
| 389 |
-
iq_res, gt_res = model(in_feature, gt_feature)
|
| 390 |
-
|
| 391 |
-
print(iq_res)
|
| 392 |
-
print(gt_res.shape)
|
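A note on the `expand` helper that appears in the model files above and below: it builds, for every sample in a batch, the stack of the other samples' prompt features (a leave-one-out expansion). The following standalone sketch is not repository code; it only illustrates the shape behaviour under the assumption of a plain 2-D feature tensor.

import torch

def leave_one_out(A: torch.Tensor) -> torch.Tensor:
    # A: (N, C) -> (N, N - 1, C); row i stacks every vector except A[i]
    out = []
    for i in range(A.size(0)):
        out.append(torch.cat((A[:i], A[i + 1:])).unsqueeze(0))
    return torch.cat(out, dim=0)

features = torch.randn(4, 128)        # e.g. 4 image-score prompt features of dim 128
print(leave_one_out(features).shape)  # torch.Size([4, 3, 128])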
PromptIQA/models/{monet.py → promptiqa.py}
RENAMED

@@ -6,10 +6,8 @@ import torch.nn as nn
 import timm
 
 from timm.models.vision_transformer import Block
+from functools import partial
 from einops import rearrange
-from itertools import combinations
-
-from tqdm import tqdm
 
 class Attention_Block(nn.Module):
     def __init__(self, dim, drop=0.1):
@@ -119,20 +117,6 @@ class SaveOutput:
     def clear(self):
         self.outputs = []
 
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
@@ -160,8 +144,7 @@ class Attention(nn.Module):
         x = self.proj_drop(x)
         return x
 
-
-class MoNet(nn.Module):
+class PromptIQA(nn.Module):
     def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
         super().__init__()
         self.img_size = img_size
@@ -273,15 +256,10 @@ class MoNet(nn.Module):
         fusion = self.norm(fusion)
         fusion = self.score_block(fusion)
 
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
         iq_res = fusion[:, 0].view(-1)
 
         return iq_res
 
-    @torch.no_grad()
-    def check_prompt(self, data_type):
-        return data_type in self.prompt_feature
-
     @torch.no_grad()
     def forward_prompt(self, x, score, data_type):
         _x = self.vit(x)
@@ -304,63 +282,13 @@ class MoNet(nn.Module):
         IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
         img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
 
-        # linearly project the score to 128 dims
-        # score_feature = self.score_projection(score)  # bs, 128
         score_feature = score.expand(-1, 128)
 
-        # fuse img_feature and score_feature into funsion_feature
         funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
         funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128
 
-        # print('Load Prompt For Testing.', funsion_feature.shape)
-        # self.prompt_feature = funsion_feature.clone()
         self.prompt_feature[data_type] = funsion_feature.clone()
 
-    def forward(self, x, score):
-        _x = self.vit(x)
-        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
-        self.save_output.outputs.clear()
-
-        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
-        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
-        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28
-
-        # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
-        for index, _ in enumerate(self.MALs):
-            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
-        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
-
-        # Image Quality Score Regression
-        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
-        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
-        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
-        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128
-
-        # linearly project the score to 128 dims
-        # score_feature = self.score_projection(score)  # bs, 128
-        score_feature = score.expand(-1, 128)  # bs, 128
-
-        # fuse img_feature and score_feature into funsion_feature
-        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
-        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
-        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
-        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128
-
-        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
-        fusion = self.norm(fusion)
-        fusion = self.score_block(fusion)
-        # iq_res = torch.mean(fusion, dim=1).view(-1)
-        iq_res = fusion[:, 0].view(-1)
-
-        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
-        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
-
-        gt_res = score.view(-1)
-        # diff_gt_res = 1 - score.view(-1)
-
-        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
-
     def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
         x1 = save_output.outputs[block_index[0]][:, 1:]
         x2 = save_output.outputs[block_index[1]][:, 1:]
@@ -381,13 +309,3 @@ class MoNet(nn.Module):
             B = torch.cat((B, rmv), dim=0)
 
         return B
-
-if __name__ == '__main__':
-    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
-    gt_feature = torch.tensor([[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=torch.float).cuda()
-    model = MoNet().cuda()
-
-    iq_res, gt_res = model(in_feature, gt_feature)
-
-    print(iq_res.shape)
-    print(gt_res.shape)
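For orientation, a minimal usage sketch of the renamed model follows. It is not part of the commit; it assumes the post-commit layout (PromptIQA/models/promptiqa.py exposing the PromptIQA class), a CUDA device (the model allocates tensors with .cuda() internally), and hypothetical example images with scores in [0, 1].

import torch
from PromptIQA.models.promptiqa import PromptIQA  # assumed import path after the rename

model = PromptIQA().cuda().eval()

# 1) Store a prompt built from a few example images with known scores.
prompt_imgs = torch.rand(10, 3, 224, 224).cuda()                 # hypothetical examples
prompt_scores = torch.linspace(0, 1, 10).reshape(-1, 1).cuda()   # one score per image
model.forward_prompt(prompt_imgs, prompt_scores, 'my_dataset')

# 2) Score a new image against the stored prompt.
query = torch.rand(1, 3, 224, 224).cuda()
with torch.no_grad():
    score = model.inference(query, 'my_dataset')
print(score)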
PromptIQA/models/vit_base.py
DELETED

@@ -1,402 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class three_cnn(nn.Module):
    def __init__(self, in_dim) -> None:
        super().__init__()

        self.three_cnn = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.three_cnn(input)


class MAL(nn.Module):
    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()
        self.attention_module = nn.ModuleList()
        for i in range(feature_num):
            self.attention_module.append(three_cnn(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim
        self.feature_size = feature_size

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
        feature = torch.mean(feature, dim=1)
        features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)

        return features  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


from functools import partial


class MoNet(nn.Module):
    def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(1):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_mal = MAL(feature_num=1)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )

        self.i_p_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 128),
        )

        self.prompt_fusion = nn.Sequential(
            Block(128, 4),
            Block(128, 4),
            Block(128, 4),
        )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0,
                attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(8)])
        self.norm = nn.LayerNorm(128)

        self.score_block = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # bs, 4, 768, 28 * 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        if data_type in self.prompt_feature:
            return
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # linearly project the score to 128 dims
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)

        # fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        # print('Load Prompt For Testing.', funsion_feature.shape)
        # self.prompt_feature = funsion_feature.clone()
        self.prompt_feature[data_type] = funsion_feature.clone()

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # linearly project the score to 128 dims
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 128)  # bs, 128

        # fuse img_feature and score_feature into funsion_feature
        # funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1))  # bs, 2, 128
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=None):
        block_index = [2, 5, 8, 11]
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B


if __name__ == '__main__':
    in_feature = torch.zeros((11, 3, 384, 384)).cuda()
    gt_feature = torch.tensor(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
    gt_feature = gt_feature.reshape(-1, 1)
    model = MoNet().cuda()

    (iq_res, _), (_, _) = model(in_feature, gt_feature)

    print(iq_res.shape)
    # print(gt_res.shape)
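The deleted backbones above all rely on the same hook pattern: a SaveOutput callable is registered as a forward hook on every transformer Block, so a single call to the ViT leaves each block's output available for extract_feature. The following standalone sketch (not repository code; it uses plain Linear layers as stand-ins for the ViT blocks) shows that mechanism in isolation.

import torch
import torch.nn as nn

class SaveOutput:
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
    def clear(self):
        self.outputs = []

# Stand-in for the ViT: three "blocks" whose intermediate outputs we want to keep.
blocks = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))
save_output = SaveOutput()
handles = [layer.register_forward_hook(save_output) for layer in blocks]

_ = blocks(torch.randn(2, 8))
print(len(save_output.outputs))      # 3 captured activations, one per block
print(save_output.outputs[0].shape)  # torch.Size([2, 8])
save_output.clear()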
PromptIQA/models/vit_large.py
DELETED

@@ -1,405 +0,0 @@
"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange
from itertools import combinations

from tqdm import tqdm


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous()
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous())
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class three_cnn(nn.Module):
    def __init__(self, in_dim) -> None:
        super().__init__()

        self.three_cnn = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.three_cnn(input)


class MAL(nn.Module):
    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()
        self.attention_module = nn.ModuleList()
        for i in range(feature_num):
            self.attention_module.append(three_cnn(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim
        self.feature_size = feature_size

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1)
        feature = torch.mean(feature, dim=1)
        features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size)

        return features  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


from functools import partial


class MoNet(nn.Module):
    def __init__(self, patch_size=32, drop=0.1, dim_mlp=1024, img_size=384):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model('vit_large_patch32_384', pretrained=True)
        self.vit.norm = nn.Identity()
        self.vit.head = nn.Identity()
        self.vit.head_drop = nn.Identity()

        self.save_output = SaveOutput()

        # Register Hooks
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(3):
            self.MALs.append(MAL(in_dim=dim_mlp, feature_size=self.input_size))

        # Image Quality Score Regression
        self.fusion_mal = MAL(in_dim=dim_mlp, feature_num=3, feature_size=self.input_size)
        self.block = Block(dim_mlp, 16)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 512, 5),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),  # 4
            nn.Conv2d(512, 256, 3, 1),  # 2
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
        )

        self.i_p_fusion = nn.Sequential(
            Block(256, 8),
            Block(256, 8),
            Block(256, 8),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 256),
        )

        self.prompt_fusion = nn.Sequential(
            Block(256, 8),
            Block(256, 8),
            Block(256, 8),
        )

        dpr = [x.item() for x in torch.linspace(0, 0, 8)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=256, num_heads=8, mlp_ratio=4, qkv_bias=True, attn_drop=0, drop_path=dpr[i])
            for i in range(8)])
        self.norm = nn.LayerNorm(256)

        self.score_block = nn.Sequential(
            nn.Linear(256, 256 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(256 // 2, 1),
            nn.Sigmoid()
        )
        self.prompt_feature = {}

    @torch.no_grad()
    def clear(self):
        self.prompt_feature = {}

    @torch.no_grad()
    def inference(self, x, data_type):
        prompt_feature = self.prompt_feature[data_type]  # 1, n, 128

        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
                      h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # bs, 4, 768, 28 * 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
                               h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1)  # bs, n, 128
        prompt_feature = self.prompt_fusion(prompt_feature)  # bs, n, 128

        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)

        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        return iq_res

    @torch.no_grad()
    def check_prompt(self, data_type):
        return data_type in self.prompt_feature

    @torch.no_grad()
    def forward_prompt(self, x, score, data_type):
        if data_type in self.prompt_feature:
            return
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
                      h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28

        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
                               h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # linearly project the score to 128 dims
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 256)

        # fuse img_feature and score_feature into funsion_feature
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0)  # 1, n, 128

        # print('Load Prompt For Testing.', funsion_feature.shape)
        # self.prompt_feature = funsion_feature.clone()
        self.prompt_feature[data_type] = funsion_feature.clone()

    def forward(self, x, score):
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1).contiguous()  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size,
                      h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4).contiguous()  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # M, bs, 768, 28, 28
        # Image Quality Score Regression
        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous()  # bs, 28 * 28, 768
        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous()  # bs, 768, 28 * 28
        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size,
                               h=self.input_size)  # bs, 768, 28, 28
        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1)  # bs, 1, 128

        # linearly project the score to 128 dims
        # score_feature = self.score_projection(score)  # bs, 128
        score_feature = score.expand(-1, 256)  # bs, 128

        # fuse img_feature and score_feature into funsion_feature; funsion_feature = self.i_p_fusion(torch.cat((
        # img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1))  # bs, 2, 128
        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1))  # bs, 2, 128
        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1))  # bs, 128
        funsion_feature = self.expand(funsion_feature)  # bs, bs - 1, 128
        funsion_feature = self.prompt_fusion(funsion_feature)  # bs, bs - 1, 128

        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1))  # bs, 2, 1
        fusion = self.norm(fusion)
        fusion = self.score_block(fusion)
        # iq_res = torch.mean(fusion, dim=1).view(-1)
        iq_res = fusion[:, 0].view(-1)

        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1))  # bs, n, 1
        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)

        gt_res = score.view(-1)
        # diff_gt_res = 1 - score.view(-1)

        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')

    def extract_feature(self, save_output, block_index=None):
        if block_index is None:
            block_index = [5, 11, 17, 23]
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def expand(self, A):
        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)

        B = None
        for index, i in enumerate(A_expanded):
            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
            if B is None:
                B = rmv
            else:
                B = torch.cat((B, rmv), dim=0)

        return B


if __name__ == '__main__':
    in_feature = torch.zeros((11, 3, 384, 384)).cuda()
    gt_feature = torch.tensor(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda()
    gt_feature = gt_feature.reshape(-1, 1)
    model = MoNet().cuda()

    (iq_res, _), (_, _) = model(in_feature, gt_feature)

    print(iq_res.shape)
    # print(gt_res.shape)
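The rearrange patterns in the deleted backbones assume a token grid whose side is img_size // patch_size. A quick illustrative check (not repository code) of those shape assumptions for the two ViT variants used above:

for name, img_size, patch_size in [('vit_large_patch32_384', 384, 32),
                                   ('vit_base_patch8_224', 224, 8)]:
    side = img_size // patch_size
    print(name, side, side * side)
# vit_large_patch32_384 12 144  -> 12 x 12 patch tokens per image
# vit_base_patch8_224 28 784    -> 28 x 28 patch tokens per image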
PromptIQA/run_promptIQA copy.py
DELETED
@@ -1,109 +0,0 @@
-import os
-import random
-import torchvision
-import cv2
-import torch
-from models import monet as MoNet
-import numpy as np
-from utils.dataset.process import ToTensor, Normalize
-from utils.toolkit import *
-import warnings
-warnings.filterwarnings('ignore')
-
-import sys
-sys.path.append(os.path.dirname(__file__))
-
-class PromptIQA():
-    def __init__(self) -> None:
-        pass
-
-def load_image(img_path, size=224):
-    try:
-        d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
-        d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC)
-        d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
-        d_img = np.array(d_img).astype('float32') / 255
-        d_img = np.transpose(d_img, (2, 0, 1))
-    except:
-        print(img_path)
-
-    return d_img
-
-def load_model(pkl_path):
-
-    model = MoNet.MoNet()
-    dict_pkl = {}
-    # prompt_num = torch.load(pkl_path, map_location='cpu').get('prompt_num')
-    for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
-        dict_pkl[key[7:]] = value
-    model.load_state_dict(dict_pkl)
-    print('Load Model From ', pkl_path)
-
-    return model
-
-def get_an_img_score(img_path, target):
-    transform=torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
-    values_to_insert = np.array([0.0, 1.0])
-    position_to_insert = 0
-    target = np.insert(target, position_to_insert, values_to_insert)
-
-    sample = load_image(img_path)
-    samples = {'img': sample, 'gt': target}
-    samples = transform(samples)
-
-    return samples
-
-import random
-if __name__ == '__main__':
-    pkl_path = "./checkpoints/best_model_five_22.pth.tar"
-    model = load_model(pkl_path).cuda()
-    model.eval()
-
-    img_path = '/mnt/storage/PromptIQA_Demo/CSIQ/dst_src'
-
-    img_tensor, gt_tensor = None, None
-    img_list = os.listdir(img_path)
-    random.shuffle(img_list)
-    for idx, img_name in enumerate(img_list):
-        if idx == 10:
-            break
-
-        img_name = os.path.join(img_path, img_name)
-        score = np.array(idx / 10)
-        samples = get_an_img_score(img_name, score)
-
-        if img_tensor is None:
-            img_tensor = samples['img'].unsqueeze(0)
-            gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0)
-        else:
-            img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0)
-            gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0)
-
-    print(img_tensor.shape)
-    print(gt_tensor.shape)
-    print(gt_tensor)
-
-    img = img_tensor.squeeze(0).cuda()
-    label = gt_tensor.squeeze(0).cuda()
-    reverse = False
-    if reverse == 2:
-        label = torch.rand_like(label[:, -1]).cuda()
-        print(label)
-    elif reverse == 3:
-        print('Total Random')
-        label = torch.rand_like(label[:, -1]).cuda()
-        img = torch.rand_like(img).cuda()
-    else:
-        label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
-    print('input label: ', label)
-    model.forward_prompt(img, label.reshape(-1, 1), 'livec')
-
-    img_name = '/mnt/storage/PromptIQA_Demo/CSIQ/src_imgs/1600.png'
-    score = np.array(random.random())
-    samples = get_an_img_score(img_name, score)
-
-    img = samples['img'].unsqueeze(0).cuda()
-    print(img.shape)
-    pred = model.inference(img, 'livec')
-
-    print(pred)
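
Note: the deleted demo above still imported the old models.monet module and exercised the model's two-step API: forward_prompt caches an image-score-pair prompt under a task key, and inference then scores a new image against that cached prompt. Below is a minimal sketch of the same flow against the renamed promptiqa module kept by this commit; the random tensors, the prompt length of 10, and the 'livec' task key are placeholders, not values fixed by the repository.

import torch
from PromptIQA.models import promptiqa

# Sketch only: untrained weights, random stand-in images, placeholder task key.
model = promptiqa.PromptIQA().cuda().eval()

prompt_imgs = torch.rand(10, 3, 224, 224).cuda()                # 10 prompt images
prompt_scores = torch.linspace(0, 1, 10).reshape(-1, 1).cuda()  # their known scores
model.forward_prompt(prompt_imgs, prompt_scores, 'livec')       # cache the prompt

query_img = torch.rand(1, 3, 224, 224).cuda()
pred = model.inference(query_img, 'livec')                      # score against the cached prompt
print(pred)
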
PromptIQA/run_promptIQA.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import torchvision
 import cv2
 import torch
-from PromptIQA.models import
+from PromptIQA.models import promptiqa
 import numpy as np
 from PromptIQA.utils.dataset.process import ToTensor, Normalize
 from PromptIQA.utils.toolkit import *
@@ -14,7 +14,7 @@ import sys
 sys.path.append(os.path.dirname(__file__))

 def load_model(pkl_path):
-    model =
+    model = promptiqa.PromptIQA()
     dict_pkl = {}
     for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items():
         dict_pkl[key[7:]] = value
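
One detail worth noting in the retained loader: key[7:] strips the 'module.' prefix (seven characters) that DataParallel/DistributedDataParallel prepends to every parameter name when a wrapped model is saved, so the checkpoint can be loaded into a bare model. A slightly more explicit sketch of the same step; the function name below is mine, not part of the repository.

import torch

def strip_module_prefix(pkl_path):
    # Drop the 'module.' prefix added by (Distributed)DataParallel when the model was saved.
    ckpt = torch.load(pkl_path, map_location='cpu')
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in ckpt['state_dict'].items()}

# model.load_state_dict(strip_module_prefix('best_model.pth.tar'))
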
PromptIQA/t.py
DELETED
@@ -1,2 +0,0 @@
-a = "(1+1)**(2**2)"
-print(eval(a))
PromptIQA/test.py
DELETED
|
@@ -1,429 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
|
| 3 |
-
from utils import log_writer
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import builtins
|
| 7 |
-
import os
|
| 8 |
-
import random
|
| 9 |
-
import shutil
|
| 10 |
-
import time
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
import torch.distributed as dist
|
| 14 |
-
import torch.multiprocessing as mp
|
| 15 |
-
import torch.nn as nn
|
| 16 |
-
import torch.nn.parallel
|
| 17 |
-
import torch.optim
|
| 18 |
-
import torch.utils.data
|
| 19 |
-
import torch.utils.data.distributed
|
| 20 |
-
# from models import monet as MoNet
|
| 21 |
-
from torch.utils.data import ConcatDataset
|
| 22 |
-
from utils.dataset import data_loader
|
| 23 |
-
|
| 24 |
-
from utils.toolkit import *
|
| 25 |
-
|
| 26 |
-
loger_path = None
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def init(config):
|
| 30 |
-
global loger_path
|
| 31 |
-
if config.dist_url == "env://" and config.world_size == -1:
|
| 32 |
-
config.world_size = int(os.environ["WORLD_SIZE"])
|
| 33 |
-
|
| 34 |
-
config.distributed = config.world_size > 1 or config.multiprocessing_distributed
|
| 35 |
-
|
| 36 |
-
print("config.distributed", config.distributed)
|
| 37 |
-
|
| 38 |
-
loger_path = os.path.join(config.save_path, "inference_log")
|
| 39 |
-
if not os.path.isdir(loger_path):
|
| 40 |
-
os.makedirs(loger_path)
|
| 41 |
-
|
| 42 |
-
print("----------------------------------")
|
| 43 |
-
print(
|
| 44 |
-
"Begin Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
|
| 45 |
-
)
|
| 46 |
-
printArgs(config, loger_path)
|
| 47 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '2,3,4,5,6,7'
|
| 48 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
| 49 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5'
|
| 50 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '6,7'
|
| 51 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
|
| 52 |
-
# setup_seed(config.seed)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def main(config):
|
| 56 |
-
init(config)
|
| 57 |
-
ngpus_per_node = torch.cuda.device_count()
|
| 58 |
-
if config.multiprocessing_distributed:
|
| 59 |
-
config.world_size = ngpus_per_node * config.world_size
|
| 60 |
-
|
| 61 |
-
print(config.world_size, ngpus_per_node, ngpus_per_node)
|
| 62 |
-
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
|
| 63 |
-
else:
|
| 64 |
-
# Simply call main_worker function
|
| 65 |
-
main_worker(config.gpu, ngpus_per_node, config)
|
| 66 |
-
|
| 67 |
-
print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
@torch.no_grad()
|
| 71 |
-
def gather_together(data): # helper that collects `data` from every GPU and returns the gathered results as a list
|
| 72 |
-
dist.barrier()
|
| 73 |
-
world_size = dist.get_world_size()
|
| 74 |
-
gather_data = [None for _ in range(world_size)]
|
| 75 |
-
dist.all_gather_object(gather_data, data)
|
| 76 |
-
return gather_data
|
| 77 |
-
|
| 78 |
-
import importlib.util
|
| 79 |
-
def main_worker(gpu, ngpus_per_node, args):
|
| 80 |
-
models_path = os.path.join(args.save_path, "training_files", 'models', 'monet.py')
|
| 81 |
-
spec = importlib.util.spec_from_file_location("monet_module", models_path)
|
| 82 |
-
monet_module = importlib.util.module_from_spec(spec)
|
| 83 |
-
spec.loader.exec_module(monet_module)
|
| 84 |
-
MoNet = monet_module
|
| 85 |
-
|
| 86 |
-
loger_path = os.path.join(args.save_path, "inference_log")
|
| 87 |
-
if gpu == 0:
|
| 88 |
-
sys.stdout = log_writer.Logger(os.path.join(loger_path, f"inference_log_{args.prompt_type}_{args.reverse}.log"))
|
| 89 |
-
args.gpu = gpu
|
| 90 |
-
|
| 91 |
-
# suppress printing if not master
|
| 92 |
-
if args.multiprocessing_distributed and args.gpu != 0:
|
| 93 |
-
def print_pass(*args):
|
| 94 |
-
pass
|
| 95 |
-
|
| 96 |
-
builtins.print = print_pass
|
| 97 |
-
|
| 98 |
-
if args.gpu is not None:
|
| 99 |
-
print("Use GPU: {} for testing".format(args.gpu))
|
| 100 |
-
|
| 101 |
-
if args.distributed:
|
| 102 |
-
if args.dist_url == "env://" and args.rank == -1:
|
| 103 |
-
args.rank = int(os.environ["RANK"])
|
| 104 |
-
if args.multiprocessing_distributed:
|
| 105 |
-
args.rank = args.rank * ngpus_per_node + gpu
|
| 106 |
-
dist.init_process_group(
|
| 107 |
-
backend=args.dist_backend,
|
| 108 |
-
init_method=args.dist_url,
|
| 109 |
-
world_size=args.world_size,
|
| 110 |
-
rank=args.rank,
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
# create model
|
| 114 |
-
model = MoNet.MoNet()
|
| 115 |
-
dict_pkl = {}
|
| 116 |
-
prompt_num = torch.load(args.pkl_path, map_location='cpu').get('prompt_num')
|
| 117 |
-
for key, value in torch.load(args.pkl_path, map_location='cpu')['state_dict'].items():
|
| 118 |
-
dict_pkl[key[7:]] = value
|
| 119 |
-
model.load_state_dict(dict_pkl)
|
| 120 |
-
print('Load Model From ', args.pkl_path)
|
| 121 |
-
|
| 122 |
-
if args.distributed:
|
| 123 |
-
if args.gpu is not None:
|
| 124 |
-
torch.cuda.set_device(args.gpu)
|
| 125 |
-
model.cuda(args.gpu)
|
| 126 |
-
args.batch_size = int(args.batch_size / ngpus_per_node)
|
| 127 |
-
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
| 128 |
-
model = torch.nn.parallel.DistributedDataParallel(
|
| 129 |
-
model, device_ids=[args.gpu]
|
| 130 |
-
)
|
| 131 |
-
print("Model Distribute.")
|
| 132 |
-
else:
|
| 133 |
-
model.cuda()
|
| 134 |
-
model = torch.nn.parallel.DistributedDataParallel(model)
|
| 135 |
-
|
| 136 |
-
if prompt_num is None:
|
| 137 |
-
prompt_num = args.batch_size - 1
|
| 138 |
-
prompt_num = 10
|
| 139 |
-
print('prompt_num', prompt_num)
|
| 140 |
-
|
| 141 |
-
test_prompt_list, test_data_list = {}, []
|
| 142 |
-
# fix_prompt = None
|
| 143 |
-
for dataset in args.dataset:
|
| 144 |
-
print('---Load ', dataset)
|
| 145 |
-
path, train_index, test_index = get_data(dataset=dataset, split_seed=args.seed)
|
| 146 |
-
# if dataset == 'spaq' and False:
|
| 147 |
-
if dataset == 'spaq':
|
| 148 |
-
for column in range(2, 8):
|
| 149 |
-
print('sapq column train', column)
|
| 150 |
-
test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, column=column)
|
| 151 |
-
test_data_list.append(test_dataset.get_samples())
|
| 152 |
-
|
| 153 |
-
train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, column=column)
|
| 154 |
-
test_prompt_list[dataset+f'_{column}'] = train_dataset.get_prompt(prompt_num, args.prompt_type)
|
| 155 |
-
else:
|
| 156 |
-
test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, types=args.types)
|
| 157 |
-
test_data_list.append(test_dataset.get_samples())
|
| 158 |
-
|
| 159 |
-
train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, types=args.types)
|
| 160 |
-
test_prompt_list[dataset] = train_dataset.get_prompt(prompt_num, args.prompt_type)
|
| 161 |
-
print('args.prompt_type', args.prompt_type)
|
| 162 |
-
|
| 163 |
-
combined_test_samples = ConcatDataset(test_data_list)
|
| 164 |
-
print("test_dataset", len(combined_test_samples))
|
| 165 |
-
test_sampler = torch.utils.data.distributed.DistributedSampler(combined_test_samples)
|
| 166 |
-
|
| 167 |
-
test_loader = torch.utils.data.DataLoader(
|
| 168 |
-
combined_test_samples,
|
| 169 |
-
batch_size=1,
|
| 170 |
-
shuffle=(test_sampler is None),
|
| 171 |
-
num_workers=args.workers,
|
| 172 |
-
sampler=test_sampler,
|
| 173 |
-
drop_last=False,
|
| 174 |
-
pin_memory=True,
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
if args.distributed:
|
| 178 |
-
test_sampler.set_epoch(0)
|
| 179 |
-
|
| 180 |
-
for idxsa in range(1):
|
| 181 |
-
test_srocc, test_plcc, pred_scores, gt_scores, path = test(
|
| 182 |
-
test_loader, model, test_prompt_list, reverse=args.reverse
|
| 183 |
-
)
|
| 184 |
-
print('gt_scores', len(pred_scores), len(gt_scores))
|
| 185 |
-
print('Summary---')
|
| 186 |
-
|
| 187 |
-
gt_scores = gather_together(gt_scores) # gather the per-rank lists from all GPUs into one list
|
| 188 |
-
pred_scores = gather_together(pred_scores) # gather the per-rank lists from all GPUs into one list
|
| 189 |
-
|
| 190 |
-
gt_score_dict, pred_score_dict = {}, {}
|
| 191 |
-
for sublist in gt_scores:
|
| 192 |
-
for k, v in sublist.items():
|
| 193 |
-
if k not in gt_score_dict:
|
| 194 |
-
gt_score_dict[k] = v
|
| 195 |
-
else:
|
| 196 |
-
gt_score_dict[k] = gt_score_dict[k] + v
|
| 197 |
-
|
| 198 |
-
for sublist in pred_scores:
|
| 199 |
-
for k, v in sublist.items():
|
| 200 |
-
if k not in pred_score_dict:
|
| 201 |
-
pred_score_dict[k] = v
|
| 202 |
-
else:
|
| 203 |
-
pred_score_dict[k] = pred_score_dict[k] + v
|
| 204 |
-
|
| 205 |
-
gt_score_dict = dict(sorted(gt_score_dict.items()))
|
| 206 |
-
test_srocc, test_plcc = 0, 0
|
| 207 |
-
for k, v in gt_score_dict.items():
|
| 208 |
-
test_srocc_, test_plcc_ = cal_srocc_plcc(gt_score_dict[k], pred_score_dict[k])
|
| 209 |
-
print('\t{} Test SROCC: {}, PLCC: {}'.format(k, round(test_srocc_, 4), round(test_plcc_, 4)))
|
| 210 |
-
# print('Pred: ', pred_score_dict[k][:10])
|
| 211 |
-
# print('GT: ', gt_score_dict[k][:10])
|
| 212 |
-
# print('-----')
|
| 213 |
-
|
| 214 |
-
with open('{}_{}.csv'.format(idxsa, k), 'w') as f:
|
| 215 |
-
for i, j in zip(gt_score_dict[k], pred_score_dict[k]):
|
| 216 |
-
f.write('{},{}\n'.format(i, j))
|
| 217 |
-
test_srocc += test_srocc_
|
| 218 |
-
test_plcc += test_plcc_
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
def test(test_loader, MoNet, promt_data_loader, reverse=False):
|
| 222 |
-
"""Training"""
|
| 223 |
-
pred_scores = {}
|
| 224 |
-
gt_scores = {}
|
| 225 |
-
path = []
|
| 226 |
-
|
| 227 |
-
batch_time = AverageMeter("Time", ":6.3f")
|
| 228 |
-
srocc = AverageMeter("SROCC", ":6.2f")
|
| 229 |
-
plcc = AverageMeter("PLCC", ":6.2f")
|
| 230 |
-
progress = ProgressMeter(
|
| 231 |
-
len(test_loader),
|
| 232 |
-
[batch_time, srocc, plcc],
|
| 233 |
-
prefix="Testing ",
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
print('reverse ----', reverse)
|
| 237 |
-
MoNet.train(False)
|
| 238 |
-
with torch.no_grad():
|
| 239 |
-
for index, (img_or, label_or, paths, dataset_type) in enumerate(test_loader):
|
| 240 |
-
# print(dataset_type)
|
| 241 |
-
t = time.time()
|
| 242 |
-
dataset_type = dataset_type[0]
|
| 243 |
-
|
| 244 |
-
has_prompt = False
|
| 245 |
-
if hasattr(MoNet.module, 'check_prompt'):
|
| 246 |
-
has_prompt = MoNet.module.check_prompt(dataset_type)
|
| 247 |
-
|
| 248 |
-
if not has_prompt:
|
| 249 |
-
print('Load Prompt For ', dataset_type)
|
| 250 |
-
prompt_dataset = promt_data_loader[dataset_type]
|
| 251 |
-
for img, label in prompt_dataset:
|
| 252 |
-
img = img.squeeze(0).cuda()
|
| 253 |
-
label = label.squeeze(0).cuda()
|
| 254 |
-
if reverse == 2:
|
| 255 |
-
# label = torch.tensor([random.random() for i in range(len(label[:, -1]))]).cuda()
|
| 256 |
-
#
|
| 257 |
-
label = torch.rand_like(label[:, -1]).cuda()
|
| 258 |
-
print(label)
|
| 259 |
-
elif reverse == 3:
|
| 260 |
-
print('Total Random')
|
| 261 |
-
label = torch.rand_like(label[:, -1]).cuda()
|
| 262 |
-
img = torch.rand_like(img).cuda()
|
| 263 |
-
else:
|
| 264 |
-
label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda())
|
| 265 |
-
MoNet.module.forward_prompt(img, label.reshape(-1, 1), dataset_type)
|
| 266 |
-
|
| 267 |
-
img = img_or.squeeze(0).cuda()
|
| 268 |
-
label = label_or.squeeze(0).cuda()[:, 2]
|
| 269 |
-
|
| 270 |
-
# print(img.shape)
|
| 271 |
-
|
| 272 |
-
pred = MoNet.module.inference(img, dataset_type)
|
| 273 |
-
|
| 274 |
-
if dataset_type not in pred_scores:
|
| 275 |
-
pred_scores[dataset_type] = []
|
| 276 |
-
|
| 277 |
-
if dataset_type not in gt_scores:
|
| 278 |
-
gt_scores[dataset_type] = []
|
| 279 |
-
|
| 280 |
-
pred_scores[dataset_type] = pred_scores[dataset_type] + pred.cpu().tolist()
|
| 281 |
-
gt_scores[dataset_type] = gt_scores[dataset_type] + label.cpu().tolist()
|
| 282 |
-
path = path + list(paths)
|
| 283 |
-
|
| 284 |
-
batch_time.update(time.time() - t)
|
| 285 |
-
|
| 286 |
-
if index % 100 == 0:
|
| 287 |
-
for k, v in pred_scores.items():
|
| 288 |
-
test_srocc, test_plcc = cal_srocc_plcc(pred_scores[k], gt_scores[k])
|
| 289 |
-
# print('\t{}, SROCC: {}, PLCC: {}'.format(k, round(test_srocc, 4), round(test_plcc, 4)))
|
| 290 |
-
srocc.update(test_srocc)
|
| 291 |
-
plcc.update(test_plcc)
|
| 292 |
-
|
| 293 |
-
progress.display(index)
|
| 294 |
-
|
| 295 |
-
MoNet.module.clear()
|
| 296 |
-
# MoNet.train(True)
|
| 297 |
-
return 'test_srocc', 'test_plcc', pred_scores, gt_scores, path
|
| 298 |
-
|
| 299 |
-
if __name__ == "__main__":
|
| 300 |
-
parser = argparse.ArgumentParser()
|
| 301 |
-
parser.add_argument(
|
| 302 |
-
"--seed",
|
| 303 |
-
dest="seed",
|
| 304 |
-
type=int,
|
| 305 |
-
default=570908,
|
| 306 |
-
help="Random seeds for result reproduction.",
|
| 307 |
-
)
|
| 308 |
-
|
| 309 |
-
parser.add_argument(
|
| 310 |
-
"--mal_num",
|
| 311 |
-
dest="mal_num",
|
| 312 |
-
type=int,
|
| 313 |
-
default=2,
|
| 314 |
-
help="The number of the MAL modules.",
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
# data related
|
| 318 |
-
parser.add_argument(
|
| 319 |
-
"--dataset",
|
| 320 |
-
dest="dataset",
|
| 321 |
-
nargs='+', default=None,
|
| 322 |
-
help="Support datasets: livec|koniq10k|bid|spaq",
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
# training related
|
| 326 |
-
parser.add_argument(
|
| 327 |
-
"--queue_ratio",
|
| 328 |
-
dest="queue_ratio",
|
| 329 |
-
type=float,
|
| 330 |
-
default=0.6,
|
| 331 |
-
help="Ratio of queue length used in GC loss to training set length.",
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
parser.add_argument(
|
| 335 |
-
"--loss",
|
| 336 |
-
dest="loss",
|
| 337 |
-
type=str,
|
| 338 |
-
default="MSE",
|
| 339 |
-
help="Loss function to use. Support losses: GC|MAE|MSE.",
|
| 340 |
-
)
|
| 341 |
-
|
| 342 |
-
parser.add_argument(
|
| 343 |
-
"--lr", dest="lr", type=float, default=1e-5, help="Learning rate"
|
| 344 |
-
)
|
| 345 |
-
|
| 346 |
-
parser.add_argument(
|
| 347 |
-
"--weight_decay",
|
| 348 |
-
dest="weight_decay",
|
| 349 |
-
type=float,
|
| 350 |
-
default=1e-5,
|
| 351 |
-
help="Weight decay",
|
| 352 |
-
)
|
| 353 |
-
parser.add_argument(
|
| 354 |
-
"--batch_size", dest="batch_size", type=int, default=11, help="Batch size"
|
| 355 |
-
)
|
| 356 |
-
parser.add_argument(
|
| 357 |
-
"--epochs", dest="epochs", type=int, default=50, help="Epochs for training"
|
| 358 |
-
)
|
| 359 |
-
parser.add_argument(
|
| 360 |
-
"--T_max",
|
| 361 |
-
dest="T_max",
|
| 362 |
-
type=int,
|
| 363 |
-
default=50,
|
| 364 |
-
help="Hyper-parameter for CosineAnnealingLR",
|
| 365 |
-
)
|
| 366 |
-
parser.add_argument(
|
| 367 |
-
"--eta_min",
|
| 368 |
-
dest="eta_min",
|
| 369 |
-
type=int,
|
| 370 |
-
default=0,
|
| 371 |
-
help="Hyper-parameter for CosineAnnealingLR",
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
parser.add_argument(
|
| 375 |
-
"-j",
|
| 376 |
-
"--workers",
|
| 377 |
-
default=32,
|
| 378 |
-
type=int,
|
| 379 |
-
metavar="N",
|
| 380 |
-
help="number of data loading workers (default: 32)",
|
| 381 |
-
)
|
| 382 |
-
|
| 383 |
-
# result related
|
| 384 |
-
parser.add_argument(
|
| 385 |
-
"--save_path",
|
| 386 |
-
dest="save_path",
|
| 387 |
-
type=str,
|
| 388 |
-
default="./save_logs/Matrix_Comparation_Koniq_bs_25",
|
| 389 |
-
help="The path where the model and logs will be saved.",
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
parser.add_argument(
|
| 393 |
-
"--world-size",
|
| 394 |
-
default=-1,
|
| 395 |
-
type=int,
|
| 396 |
-
help="number of nodes for distributed training",
|
| 397 |
-
)
|
| 398 |
-
parser.add_argument(
|
| 399 |
-
"--rank", default=-1, type=int, help="node rank for distributed training"
|
| 400 |
-
)
|
| 401 |
-
parser.add_argument(
|
| 402 |
-
"--dist-url",
|
| 403 |
-
default="tcp://224.66.41.62:23456",
|
| 404 |
-
type=str,
|
| 405 |
-
help="url used to set up distributed training",
|
| 406 |
-
)
|
| 407 |
-
parser.add_argument(
|
| 408 |
-
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
| 409 |
-
)
|
| 410 |
-
parser.add_argument(
|
| 411 |
-
"--multiprocessing-distributed",
|
| 412 |
-
action="store_true",
|
| 413 |
-
help="Use multi-processing distributed training to launch "
|
| 414 |
-
"N processes per node, which has N GPUs. This is the "
|
| 415 |
-
"fastest way to use PyTorch for either single node or "
|
| 416 |
-
"multi node data parallel training",
|
| 417 |
-
)
|
| 418 |
-
|
| 419 |
-
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
|
| 420 |
-
parser.add_argument("--pkl_path", required=True, type=str)
|
| 421 |
-
parser.add_argument("--prompt_type", required=True, type=str)
|
| 422 |
-
parser.add_argument("--reverse", required=True, type=int)
|
| 423 |
-
parser.add_argument("--types", default='SSIM', type=str)
|
| 424 |
-
|
| 425 |
-
config = parser.parse_args()
|
| 426 |
-
|
| 427 |
-
config.save_path = os.path.dirname(config.pkl_path)
|
| 428 |
-
|
| 429 |
-
main(config)
|
|
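
The deleted harness above collected per-dataset prediction and ground-truth lists on each GPU, gathered them across ranks with dist.all_gather_object, merged the per-rank dictionaries, and only then computed SROCC/PLCC through cal_srocc_plcc from utils.toolkit. A minimal sketch of that merge-then-correlate step, assuming the project helper wraps the usual Spearman and Pearson coefficients (that mapping is my assumption, not something the diff states):

from scipy import stats

def merge_rank_dicts(gathered):
    # 'gathered' is a list with one {dataset_name: [scores]} dict per rank.
    merged = {}
    for rank_dict in gathered:
        for name, scores in rank_dict.items():
            merged.setdefault(name, []).extend(scores)
    return merged

def srocc_plcc(gt, pred):
    # SROCC = Spearman rank-order correlation, PLCC = Pearson linear correlation.
    return stats.spearmanr(gt, pred)[0], stats.pearsonr(gt, pred)[0]

# gt_all, pred_all = merge_rank_dicts(gt_scores), merge_rank_dicts(pred_scores)
# for name in gt_all:
#     print(name, srocc_plcc(gt_all[name], pred_all[name]))
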
PromptIQA/test.sh
DELETED
@@ -1,9 +0,0 @@
-# python test.py --dist-url 'tcp://localhost:10055' --dataset spaq tid2013 livec bid spaq flive --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 0 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
-# python test.py --dist-url 'tcp://localhost:12755' --dataset csiq --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 3 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Training_log/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar
-python test.py --dist-url 'tcp://localhost:12755' --dataset livec bid csiq --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar
-# reverse 0 no, 1 yes, 2 random
-
-python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar --types 'SSIM'
-
-
-CUDA_VISIBLE_DEVICES="0" python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Publication/PromptIQA_2024_WO_Norm_Score/best_model_five_22.pth.tar --types 'SSIM'
best_model.pth.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:993555b9efaeae660d2dd6f4056f13c6957628ca592a2ce74ff2e8eb5a4a2280
+size 1272842308
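
The three lines above are a Git LFS pointer, so a plain clone only contains this stub; the actual checkpoint (roughly 1.27 GB) is fetched with git lfs pull. Once the object is present, it could be handed to the loader kept in run_promptIQA.py. This is a sketch under two assumptions: the repository root is on the Python path, and load_model still returns the loaded model as its deleted copy did.

from PromptIQA.run_promptIQA import load_model  # assumption: importable from the repo root

model = load_model('best_model.pth.tar').cuda()
model.eval()
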
get_examplt.py
DELETED
@@ -1,27 +0,0 @@
-import os
-from copy import deepcopy
-
-
-isp_json = []
-path = './Examples'
-for img_dir in sorted(os.listdir(path)):
-    if os.path.isdir(os.path.join(path, img_dir)):
-        ISPP = os.path.join(path, img_dir, 'ISPP')
-
-        ispp = {}
-        ispp['Example_id'] = img_dir
-        ispp['ISPP'] = []
-        img_list = []
-        for idx, img in enumerate(sorted(os.listdir(ISPP))):
-            ispp['ISPP'].append([os.path.join(ISPP, img), idx / 10 if '1' in img_dir else 1 - idx / 10])
-
-        for file in os.listdir(os.path.join(path, img_dir)):
-            if os.path.isfile(os.path.join(path, img_dir, file)):
-                img_list.append(file)
-        ispp['Image'] = [os.path.join(path, img_dir, file), 7]
-        ispp['Remark'] = []
-        isp_json.append(deepcopy(ispp))
-
-with open('example2.json', 'w') as f:
-    import json
-    json.dump(isp_json, f, indent=4)
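
For context, the deleted helper walked ./Examples and wrote example2.json with one record per example folder: an ISPP list of [image path, normalised score] pairs plus a query Image entry. An illustrative record of the shape it produced is sketched below; the folder name and file names are made up, while the keys and the fixed value 7 come from the script above.

example_record = {
    "Example_id": "Example_1",
    "ISPP": [
        ["./Examples/Example_1/ISPP/img_0.png", 0.0],  # [prompt image, score in [0, 1]]
        ["./Examples/Example_1/ISPP/img_1.png", 0.1],
    ],
    "Image": ["./Examples/Example_1/query.png", 7],    # image to be scored
    "Remark": [],
}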