MohmedAnik committed on
Commit 06f2523 · verified · 1 Parent(s): 6ebe9b0

Upload 11 files

Files changed (12)
  1. .gitattributes +1 -0
  2. app.py +78 -0
  3. examples/car.jpg +0 -0
  4. examples/iMAC.jpg +0 -0
  5. examples/pig.jpg +0 -0
  6. examples/statue.png +3 -0
  7. gitattributes +35 -0
  8. inference.py +49 -0
  9. paths.py +4 -0
  10. requirements.txt +10 -0
  11. utils.py +304 -0
  12. vision_tower.py +161 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ examples/statue.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,78 @@
+ import gradio as gr
+ from paths import *
+ import os
+ from vision_tower import DINOv2_MLP
+ from transformers import AutoImageProcessor
+ import torch
+ from inference import *
+ from utils import *
+
+ from huggingface_hub import hf_hub_download
+ ckpt_path = hf_hub_download(repo_id="Viglong/Orient-Anything", filename="ronormsigma1/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
+ print(ckpt_path)
+
+ save_path = './'
+ device = 'cpu'
+ dino = DINOv2_MLP(
+     dino_mode='large',
+     in_dim=1024,
+     out_dim=360+180+360+2,
+     evaluate=True,
+     mask_dino=False,
+     frozen_back=False
+ )
+
+ dino.eval()
+ print('model created')
+ dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
+ dino = dino.to(device)
+ print('weights loaded')
+ val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
+
+ def infer_func(img, do_rm_bkg=True, do_infer_aug=False):
+     origin_img = Image.fromarray(img)
+     if do_infer_aug:
+         rm_bkg_img = background_preprocess(origin_img, True)
+         angles = get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device)
+     else:
+         rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
+         angles = get_3angle(rm_bkg_img, dino, val_preprocess, device)
+
+     phi = np.radians(angles[0])
+     theta = np.radians(angles[1])
+     gamma = angles[2]
+     confidence = float(angles[3])
+     if confidence > 0.5:
+         render_axis = render_3D_axis(phi, theta, gamma)
+         res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
+     else:
+         res_img = img
+
+     # axis_model = "axis.obj"
+     return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
+
+ example_files = os.listdir('examples')
+ example_files.sort()
+ example_files = [[os.path.join('examples', filename), None, None] for filename in example_files]
+ print(example_files)
+ server = gr.Interface(
+     flagging_mode='never',
+     fn=infer_func,
+     examples=example_files,
+     cache_examples=False,
+     inputs=[
+         gr.Image(height=512, width=512, label="upload your image"),
+         gr.Checkbox(label="Remove Background", value=True),
+         gr.Checkbox(label="Inference-time augmentation", value=False)
+     ],
+     outputs=[
+         gr.Image(height=512, width=512, label="result image"),
+         # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
+         gr.Textbox(lines=1, label='Azimuth (0~360°): position of the viewer in the xy plane'),
+         gr.Textbox(lines=1, label='Polar (-90~90°): height at which the viewer is located'),
+         gr.Textbox(lines=1, label='Rotation (-90~90°): rotation angle of the viewer'),
+         gr.Textbox(lines=1, label='Confidence (0~1): whether the object has a meaningful orientation')
+     ]
+ )
+
+ server.launch()
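
A minimal sketch of calling infer_func directly, outside the Gradio UI — assuming the model and preprocessor above are already loaded in the same session and that examples/car.jpg exists locally:

    import numpy as np
    from PIL import Image

    # infer_func expects a numpy array (it calls Image.fromarray internally)
    img = np.array(Image.open('examples/car.jpg').convert('RGB'))
    res_img, azimuth, polar, rotation, confidence = infer_func(img, do_rm_bkg=True, do_infer_aug=False)
    print(azimuth, polar, rotation, confidence)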
examples/car.jpg ADDED
examples/iMAC.jpg ADDED
examples/pig.jpg ADDED
examples/statue.png ADDED

Git LFS Details

  • SHA256: bc88dd340ed4a6177207ecc649654b2c12ad82e949b4acdebf49ea94ff7597a5
  • Pointer size: 131 Bytes
  • Size of remote file: 330 kB
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
inference.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from PIL import Image
+ from utils import *
+ import torch.nn.functional as F
+ import numpy as np
+
+ def get_3angle(image, dino, val_preprocess, device):
+
+     # image = Image.open(image_path).convert('RGB')
+     image_inputs = val_preprocess(images=image)
+     image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
+     with torch.no_grad():
+         dino_pred = dino(image_inputs)
+
+     gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
+     gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
+     gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+360], dim=-1)
+     confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
+     angles = torch.zeros(4)
+     angles[0] = gaus_ax_pred
+     angles[1] = gaus_pl_pred - 90
+     angles[2] = gaus_ro_pred - 180
+     angles[3] = confidence
+     return angles
+
+ def get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device):
+
+     # image = Image.open(image_path).convert('RGB')
+     image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
+     image_inputs = val_preprocess(images=image)
+     image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
+     with torch.no_grad():
+         dino_pred = dino(image_inputs)
+
+     gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
+     gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
+     gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+360], dim=-1).to(torch.float32)
+
+     gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
+     gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
+     gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
+
+     confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
+     angles = torch.zeros(4)
+     angles[0] = gaus_ax_pred
+     angles[1] = gaus_pl_pred - 90
+     angles[2] = gaus_ro_pred - 180
+     angles[3] = confidence
+     return angles
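
A small, self-contained sketch of how one 360+180+360+2 = 902-dimensional prediction decodes into angles, using a dummy tensor in place of a real model output (the bin indices below are made up for illustration):

    import torch
    import torch.nn.functional as F

    dummy_pred = torch.zeros(1, 360 + 180 + 360 + 2)
    dummy_pred[0, 45] = 1.0          # azimuth bin 45
    dummy_pred[0, 360 + 120] = 1.0   # polar bin 120 -> 120 - 90 = 30
    dummy_pred[0, 540 + 200] = 1.0   # rotation bin 200 -> 200 - 180 = 20
    dummy_pred[0, -2] = 5.0          # raise the "has orientation" logit

    azimuth    = torch.argmax(dummy_pred[:, 0:360], dim=-1)           # tensor([45])
    polar      = torch.argmax(dummy_pred[:, 360:540], dim=-1) - 90    # tensor([30])
    rotation   = torch.argmax(dummy_pred[:, 540:900], dim=-1) - 180   # tensor([20])
    confidence = F.softmax(dummy_pred[:, -2:], dim=-1)[0][0]          # ~0.99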
paths.py ADDED
@@ -0,0 +1,4 @@
+ DINO_SMALL = "facebook/dinov2-small"
+ DINO_BASE = "facebook/dinov2-base"
+ DINO_LARGE = "facebook/dinov2-large"
+ DINO_GIANT = "facebook/dinov2-giant"
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.2.1
+ transformers==4.38
+ matplotlib
+ pillow==10.2.0
+ huggingface-hub==0.26.5
+ gradio==5.9.0
+ numpy==1.26.4
+ onnxruntime
+ rembg
+ pydantic==2.10.6
utils.py ADDED
@@ -0,0 +1,304 @@
+ import rembg
+ import random
+ import torch
+ import numpy as np
+ from PIL import Image, ImageOps
+ import PIL
+ from typing import Any
+ import matplotlib.pyplot as plt
+ import io
+
+ def resize_foreground(
+     image: Image,
+     ratio: float,
+ ) -> Image:
+     image = np.array(image)
+     assert image.shape[-1] == 4
+     alpha = np.where(image[..., 3] > 0)
+     y1, y2, x1, x2 = (
+         alpha[0].min(),
+         alpha[0].max(),
+         alpha[1].min(),
+         alpha[1].max(),
+     )
+     # crop the foreground
+     fg = image[y1:y2, x1:x2]
+     # pad to square
+     size = max(fg.shape[0], fg.shape[1])
+     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+     new_image = np.pad(
+         fg,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+
+     # compute padding according to the ratio
+     new_size = int(new_image.shape[0] / ratio)
+     # pad to size, double side
+     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+     new_image = np.pad(
+         new_image,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+     new_image = Image.fromarray(new_image)
+     return new_image
+
+ def remove_background(image: Image,
+     rembg_session: Any = None,
+     force: bool = False,
+     **rembg_kwargs,
+ ) -> Image:
+     do_remove = True
+     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+         do_remove = False
+     do_remove = do_remove or force
+     if do_remove:
+         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+     return image
+
+ def random_crop(image, crop_scale=(0.8, 0.95)):
+     """
+     Randomly crop an image.
+     image (PIL.Image.Image): the input image.
+     crop_scale (tuple): (min_scale, max_scale).
+     """
+     assert isinstance(image, Image.Image), "input must be a PIL.Image.Image"
+     assert len(crop_scale) == 2 and 0 < crop_scale[0] <= crop_scale[1] <= 1
+
+     width, height = image.size
+
+     # compute the crop width and height
+     crop_width = random.randint(int(width * crop_scale[0]), int(width * crop_scale[1]))
+     crop_height = random.randint(int(height * crop_scale[0]), int(height * crop_scale[1]))
+
+     # randomly pick the top-left corner of the crop
+     left = random.randint(0, width - crop_width)
+     top = random.randint(0, height - crop_height)
+
+     # crop the image
+     cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
+
+     return cropped_image
+
+ def get_crop_images(img, num=3):
+     cropped_images = []
+     for i in range(num):
+         cropped_images.append(random_crop(img))
+     return cropped_images
+
+ def background_preprocess(input_image, do_remove_background):
+
+     rembg_session = rembg.new_session() if do_remove_background else None
+
+     if do_remove_background:
+         input_image = remove_background(input_image, rembg_session)
+         input_image = resize_foreground(input_image, 0.85)
+
+     return input_image
+
+ def remove_outliers_and_average(tensor, threshold=1.5):
+     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
+
+     q1 = torch.quantile(tensor, 0.25)
+     q3 = torch.quantile(tensor, 0.75)
+     iqr = q3 - q1
+
+     lower_bound = q1 - threshold * iqr
+     upper_bound = q3 + threshold * iqr
+
+     non_outliers = tensor[(tensor >= lower_bound) & (tensor <= upper_bound)]
+
+     if len(non_outliers) == 0:
+         return tensor.mean().item()
+
+     return non_outliers.mean().item()
+
+
+ def remove_outliers_and_average_circular(tensor, threshold=1.5):
+     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
+
+     # map the angles to points on the unit circle
+     radians = tensor * torch.pi / 180.0
+     x_coords = torch.cos(radians)
+     y_coords = torch.sin(radians)
+
+     # compute the mean vector
+     mean_x = torch.mean(x_coords)
+     mean_y = torch.mean(y_coords)
+
+     differences = torch.sqrt((x_coords - mean_x) * (x_coords - mean_x) + (y_coords - mean_y) * (y_coords - mean_y))
+
+     # compute the quartiles and the IQR
+     q1 = torch.quantile(differences, 0.25)
+     q3 = torch.quantile(differences, 0.75)
+     iqr = q3 - q1
+
+     # compute the lower and upper bounds
+     lower_bound = q1 - threshold * iqr
+     upper_bound = q3 + threshold * iqr
+
+     # keep only the non-outliers
+     non_outliers = tensor[(differences >= lower_bound) & (differences <= upper_bound)]
+
+     if len(non_outliers) == 0:
+         mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+         mean_angle = (mean_angle + 360) % 360
+         return mean_angle  # no non-outliers left: fall back to the mean direction of all angles
+
+     # recompute the mean vector over the non-outliers
+     radians = non_outliers * torch.pi / 180.0
+     x_coords = torch.cos(radians)
+     y_coords = torch.sin(radians)
+
+     mean_x = torch.mean(x_coords)
+     mean_y = torch.mean(y_coords)
+
+     mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+     mean_angle = (mean_angle + 360) % 360
+
+     return mean_angle
+
+ def scale(x):
+     # print(x)
+     # if abs(x[0])<0.1 and abs(x[1])<0.1:
+
+     #     return x*5
+     # else:
+     #     return x
+     return x*3
+
+ def get_proj2D_XYZ(phi, theta, gamma):
+     x = np.array([-1*np.sin(phi)*np.cos(gamma) - np.cos(phi)*np.sin(theta)*np.sin(gamma), np.sin(phi)*np.sin(gamma) - np.cos(phi)*np.sin(theta)*np.cos(gamma)])
+     y = np.array([-1*np.cos(phi)*np.cos(gamma) + np.sin(phi)*np.sin(theta)*np.sin(gamma), np.cos(phi)*np.sin(gamma) + np.sin(phi)*np.sin(theta)*np.cos(gamma)])
+     z = np.array([np.cos(theta)*np.sin(gamma), np.cos(theta)*np.cos(gamma)])
+     x = scale(x)
+     y = scale(y)
+     z = scale(z)
+     return x, y, z
+
+ # draw one projected 3D coordinate axis
+ def draw_axis(ax, origin, vector, color, label=None):
+     ax.quiver(origin[0], origin[1], vector[0], vector[1], angles='xy', scale_units='xy', scale=1, color=color)
+     if label is not None:
+         ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1, label, color=color, fontsize=12)
+
+ def matplotlib_2D_arrow(angles, rm_bkg_img):
+     fig, ax = plt.subplots(figsize=(8, 8))
+
+     # set the rotation angles
+     phi = np.radians(angles[0])
+     theta = np.radians(angles[1])
+     gamma = np.radians(-1*angles[2])
+
+     w, h = rm_bkg_img.size
+     if h > w:
+         extent = [-5*w/h, 5*w/h, -5, 5]
+     else:
+         extent = [-5, 5, -5*h/w, 5*h/w]
+     ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect='auto')  # extent sets the display range of the image
+
+     origin = np.array([0, 0])
+
+     # the rotated axis vectors
+     rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)
+
+     # draw arrow
+     arrow_attr = [{'point': rot_x, 'color': 'r', 'label': 'front'},
+                   {'point': rot_y, 'color': 'g', 'label': 'right'},
+                   {'point': rot_z, 'color': 'b', 'label': 'top'}]
+
+     if phi > 45 and phi <= 225:
+         order = [0, 1, 2]
+     elif phi > 225 and phi < 315:
+         order = [2, 0, 1]
+     else:
+         order = [2, 1, 0]
+
+     for i in range(3):
+         draw_axis(ax, origin, arrow_attr[order[i]]['point'], arrow_attr[order[i]]['color'], arrow_attr[order[i]]['label'])
+     # draw_axis(ax, origin, rot_y, 'g', label='right')
+     # draw_axis(ax, origin, rot_z, 'b', label='top')
+     # draw_axis(ax, origin, rot_x, 'r', label='front')
+
+     # hide the axes and the grid
+     ax.set_axis_off()
+     ax.grid(False)
+
+     # set the coordinate range
+     ax.set_xlim(-5, 5)
+     ax.set_ylim(-5, 5)
+
+ def figure_to_img(fig):
+     with io.BytesIO() as buf:
+         fig.savefig(buf, format='JPG', bbox_inches='tight')
+         buf.seek(0)
+         image = Image.open(buf).copy()
+     return image
+
+ from render import render, Model
+ import math
+ axis_model = Model("./axis.obj", texture_filename="./axis.png")
+ def render_3D_axis(phi, theta, gamma):
+     radius = 240
+     # camera_location = [radius * math.cos(phi), radius * math.sin(phi), radius * math.tan(theta)]
+     # print(camera_location)
+     camera_location = [-1*radius * math.cos(phi), -1*radius * math.tan(theta), radius * math.sin(phi)]
+     img = render(
+         # Model("res/jinx.obj", texture_filename="res/jinx.tga"),
+         axis_model,
+         height=512,
+         width=512,
+         filename="tmp_render.png",
+         cam_loc=camera_location
+     )
+     img = img.rotate(gamma)
+     return img
+
+ def overlay_images_with_scaling(center_image: Image.Image, background_image, target_size=(512, 512)):
+     """
+     Resize the foreground image to 512x512, scale the background image to fit, and overlay them center-aligned.
+     :param center_image: foreground image
+     :param background_image: background image
+     :param target_size: target size of the foreground image, (512, 512) by default
+     :return: the composited image
+     """
+     # make sure both input images are in RGBA mode
+     if center_image.mode != "RGBA":
+         center_image = center_image.convert("RGBA")
+     if background_image.mode != "RGBA":
+         background_image = background_image.convert("RGBA")
+
+     # resize the foreground image
+     center_image = center_image.resize(target_size)
+
+     # scale the background image so that it fits within the foreground size
+     bg_width, bg_height = background_image.size
+
+     # scale the background proportionally by its longer side
+     scale = target_size[0] / max(bg_width, bg_height)
+     new_width = int(bg_width * scale)
+     new_height = int(bg_height * scale)
+     resized_background = background_image.resize((new_width, new_height))
+     # compute the required padding
+     pad_width = target_size[0] - new_width
+     pad_height = target_size[1] - new_height
+
+     # split the padding between left/right and top/bottom
+     left = pad_width // 2
+     right = pad_width - left
+     top = pad_height // 2
+     bottom = pad_height - top
+
+     # add the padding
+     resized_background = ImageOps.expand(resized_background, border=(left, top, right, bottom), fill=(255, 255, 255, 255))
+
+     # paste the foreground onto the background
+     result = resized_background.copy()
+     result.paste(center_image, (0, 0), mask=center_image)
+
+     return result
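
A quick example of why remove_outliers_and_average_circular is used for the azimuth instead of a plain mean: angles straddling the 0°/360° wrap-around are averaged on the unit circle (the values below are illustrative):

    import torch
    angles = torch.tensor([358.0, 2.0, 4.0, 359.0])
    print(remove_outliers_and_average_circular(angles))  # ≈ 0.75°, whereas angles.mean() gives 180.75°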
vision_tower.py ADDED
@@ -0,0 +1,161 @@
+ import torch
+ from torch import nn
+ import torch.nn.init as init
+ import torch.nn.functional as F
+
+ from paths import *
+
+ from typing import Dict, List, Optional, Set, Tuple, Union
+ from transformers import AutoImageProcessor, AutoModel, Dinov2Model
+ from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
+ import numpy as np
+ from contextlib import nullcontext
+
+ def get_activation(activation):
+     if activation.lower() == 'gelu':
+         return nn.GELU()
+     elif activation.lower() == 'rrelu':
+         return nn.RReLU(inplace=True)
+     elif activation.lower() == 'selu':
+         return nn.SELU(inplace=True)
+     elif activation.lower() == 'silu':
+         return nn.SiLU(inplace=True)
+     elif activation.lower() == 'hardswish':
+         return nn.Hardswish(inplace=True)
+     elif activation.lower() == 'leakyrelu':
+         return nn.LeakyReLU(inplace=True)
+     elif activation.lower() == 'sigmoid':
+         return nn.Sigmoid()
+     elif activation.lower() == 'tanh':
+         return nn.Tanh()
+     else:
+         return nn.ReLU(inplace=True)
+
+
+
+ class MLP_dim(nn.Module):
+     def __init__(
+             self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
+         super().__init__()
+         self.act = get_activation(activation)
+         self.net1 = nn.Sequential(
+             nn.Linear(in_dim, int(out_dim), bias=bias),
+             nn.BatchNorm1d(int(out_dim)),
+             self.act
+         )
+         self.net2 = nn.Sequential(
+             nn.Linear(int(out_dim), out_dim, bias=bias),
+             nn.BatchNorm1d(out_dim)
+         )
+
+     def forward(self, x):
+         return self.net2(self.net1(x))
+
+ class FLIP_Dinov2Embeddings(Dinov2Embeddings):
+     """
+     Construct the CLS token, mask token, position and patch embeddings.
+     """
+
+     def __init__(self, config: Dinov2Config) -> None:
+         super().__init__(config)
+
+     def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+         batch_size, _, height, width = pixel_values.shape
+         target_dtype = self.patch_embeddings.projection.weight.dtype
+         embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+         # add the [CLS] token to the embedded patch tokens
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+         # add positional encoding to each token
+         embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+         if bool_masked_pos is not None:
+             # embeddings = torch.where(
+             #     bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+             # )
+             B, S, D = embeddings.shape
+             batch_indices = torch.arange(B).unsqueeze(1)
+             embeddings = embeddings[batch_indices, bool_masked_pos]
+
+         embeddings = self.dropout(embeddings)
+
+         return embeddings
+
+ class FLIP_DINOv2(Dinov2Model):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embeddings = FLIP_Dinov2Embeddings(config)
+
+ class DINOv2_MLP(nn.Module):
+     def __init__(self,
+                  dino_mode,
+                  in_dim,
+                  out_dim,
+                  evaluate,
+                  mask_dino,
+                  frozen_back
+                  ) -> None:
+         super().__init__()
+         # self.dinov2 = AutoModel.from_pretrained(DINO_BASE)
+         if dino_mode == 'base':
+             self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_BASE, cache_dir='./')
+         elif dino_mode == 'large':
+             self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_LARGE, cache_dir='./')
+         elif dino_mode == 'small':
+             self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_SMALL, cache_dir='./')
+         elif dino_mode == 'giant':
+             self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_GIANT, cache_dir='./')
+
+         self.down_sampler = MLP_dim(in_dim=in_dim, out_dim=out_dim)
+         self.random_mask = False
+         if not evaluate:
+             self.init_weights(self.down_sampler)
+             self.random_mask = mask_dino
+         if frozen_back:
+             self.forward_mode = torch.no_grad()
+         else:
+             self.forward_mode = nullcontext()
+
+     def forward(self, img_inputs):
+         device = self.get_device()
+         # print(img_inputs['pixel_values'].shape)
+
+         with self.forward_mode:
+             if self.random_mask:
+                 B = len(img_inputs['pixel_values'])
+                 S = 256
+                 indices = []
+                 for i in range(B):
+                     tmp = torch.randperm(S)[:S//2]
+                     tmp = tmp.sort().values + 1
+                     indices.append(tmp)
+                 indices = torch.stack(indices, dim=0)
+                 indices = torch.cat([torch.zeros(B, 1, dtype=torch.long, device='cpu'), indices], dim=1)
+                 # print(indices.shape)
+                 img_inputs['bool_masked_pos'] = indices.to(device)
+
+             dino_outputs = self.dinov2(**img_inputs)
+             dino_seq = dino_outputs.last_hidden_state
+             # B,S,_ = dino_seq.shape
+             # dino_seq = dino_seq.view(B*S,-1)
+             dino_seq = dino_seq[:, 0, :]
+
+             down_sample_out = self.down_sampler(dino_seq)
+             # down_sample_out = down_sample_out.view(B,S,-1)
+             # down_sample_out = down_sample_out[:,0,:]
+
+             return down_sample_out
+
+     def get_device(self):
+         return next(self.parameters()).device
+
+     def init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 init.constant_(m.bias, 0)
+
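
An untested sanity-check sketch for DINOv2_MLP, assuming network access to download facebook/dinov2-small and using in_dim=384, the hidden size of the small backbone (the demo above uses the large backbone with in_dim=1024):

    import torch
    from vision_tower import DINOv2_MLP

    model = DINOv2_MLP(dino_mode='small', in_dim=384, out_dim=360+180+360+2,
                       evaluate=True, mask_dino=False, frozen_back=False).eval()
    dummy = {'pixel_values': torch.randn(1, 3, 224, 224)}
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 902])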