Andrzej Kryszpiniuk commited on
Commit
7c1d300
·
1 Parent(s): 726dc4d

Update to clean FaceLift (no Gemini)

Browse files
Files changed (2) hide show
  1. app.py +3 -23
  2. gradio_app.py +112 -196
app.py CHANGED
@@ -1,25 +1,5 @@
1
- """
2
- FaceLift + Gemini - HuggingFace Space Entry Point
3
- This is the main file that HuggingFace Space will run.
4
- """
5
-
6
- import os
7
- # Set OMP_NUM_THREADS to 1 to avoid libgomp crash in HF Spaces
8
- os.environ["OMP_NUM_THREADS"] = "1"
9
-
10
- import gradio as gr
11
  from gradio_app import create_demo
12
 
13
- # HuggingFace Spaces automatically provides GPU
14
- # The app will use environment variables for API keys
15
-
16
- if __name__ == "__main__":
17
- # Create the Gradio interface
18
- demo = create_demo()
19
-
20
- # Launch with HuggingFace Space settings
21
- demo.queue().launch(
22
- server_name="0.0.0.0",
23
- server_port=7860,
24
- show_error=True
25
- )
 
 
 
 
 
 
 
 
 
 
 
1
  from gradio_app import create_demo
2
 
3
+ demo = create_demo()
4
+ demo.queue(max_size=10)
5
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
gradio_app.py CHANGED
@@ -13,8 +13,8 @@
13
  # limitations under the License.
14
 
15
  """
16
- FaceLift: Single Image 3D Face Reconstruction (Gemini Edition)
17
- Generates 3D head models from single images using Gemini 2.0 Flash and GS-LRM.
18
  """
19
 
20
  import json
@@ -30,39 +30,45 @@ from einops import rearrange
30
  from PIL import Image
31
  from huggingface_hub import snapshot_download
32
 
 
 
33
  from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
34
- from gemini_generator import GeminiGenerator
35
 
36
  # HuggingFace repository configuration
37
  HF_REPO_ID = "wlyu/OpenFaceLift"
38
 
39
  def download_weights_from_hf() -> Path:
40
- """Download model weights from HuggingFace if not already present."""
 
 
 
 
41
  workspace_dir = Path(__file__).parent
42
 
43
  # Check if weights already exist locally
 
44
  gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
 
45
 
46
- if gslrm_path.exists():
47
  print("Using local model weights")
48
  return workspace_dir
49
 
50
  print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
 
51
 
52
  # Download to checkpoints directory
53
- # Repo structure is 'gslrm/ckpt...', so we download to 'checkpoints' folder to get 'checkpoints/gslrm/ckpt...'
54
  snapshot_download(
55
  repo_id=HF_REPO_ID,
56
  local_dir=str(workspace_dir / "checkpoints"),
57
  local_dir_use_symlinks=False,
58
- allow_patterns=["gslrm/*"] # Only download GS-LRM
59
  )
60
 
61
  print("Model weights downloaded successfully!")
62
  return workspace_dir
63
 
64
  class FaceLiftPipeline:
65
- """Pipeline for FaceLift 3D head generation (Gemini Only)."""
66
 
67
  def __init__(self):
68
  # Download weights from HuggingFace if needed
@@ -70,18 +76,23 @@ class FaceLiftPipeline:
70
 
71
  # Setup paths
72
  self.output_dir = workspace_dir / "outputs"
 
73
  self.output_dir.mkdir(exist_ok=True)
74
 
75
  # Parameters
76
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
77
  self.image_size = 512
78
- self.camera_indices = [2, 1, 0, 5, 4, 3] # Front, Back, Left, Right, Top, Bottom
79
 
80
- # Initialize Gemini Generator
81
- self.gemini_generator = GeminiGenerator()
 
 
 
 
 
 
82
 
83
- # Load GS-LRM model (Reconstruction only)
84
- print("Loading GS-LRM model...")
85
  with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
86
  config = edict(yaml.safe_load(f))
87
 
@@ -97,191 +108,84 @@ class FaceLiftPipeline:
97
  self.gs_lrm_model.load_state_dict(checkpoint["model"])
98
  self.gs_lrm_model.to(self.device)
99
 
 
 
 
 
 
100
  with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
101
  self.cameras_data = json.load(f)["frames"]
102
 
103
  print("Models loaded successfully!")
104
 
105
- def generate_3d_head(self, image_path, api_key, model_type="Gemini", auto_crop=True, guidance_scale=3.0,
106
  random_seed=4, num_steps=50):
107
- """Generate 3D head from single image."""
108
  try:
109
- # Update API Key if provided
110
- if api_key:
111
- self.gemini_generator.configure_key(api_key)
112
-
113
  # Setup output directory
114
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
115
  output_dir = self.output_dir / timestamp
116
  output_dir.mkdir(exist_ok=True)
117
 
118
- # Preprocess input
119
- original_img = np.array(Image.open(image_path))
120
-
121
- # Check for pre-generated multiview image (Grid or Strip) BEFORE cropping
122
- h, w, _ = original_img.shape
123
- aspect_ratio = w / h
124
-
125
- print(f"[DEBUG] Image dimensions: {w}x{h}, Aspect ratio: {aspect_ratio:.3f}")
126
-
127
- is_strip = 5.5 < aspect_ratio < 6.5
128
- is_grid = 1.1 < aspect_ratio < 2.0 # Widened range to catch cropped/resized grids
129
-
130
- print(f"[DEBUG] is_strip: {is_strip}, is_grid: {is_grid}")
131
-
132
-
133
  selected_views = []
134
- original_views = [] # Keep original aspect ratios for multiview composite
135
 
136
- if is_strip:
137
- print("Detected pre-generated multiview image (6x1). Skipping generation & cropping.")
138
- input_image = Image.fromarray(original_img)
139
- single_view_width = w // 6
 
 
 
 
 
 
 
140
  for i in range(6):
141
- left = i * single_view_width
142
- right = (i + 1) * single_view_width
143
- view = input_image.crop((left, 0, right, h))
144
-
145
- # Pad to square (add white borders)
146
- view_w, view_h = view.size
147
- target_size = max(view_w, view_h)
148
-
149
- # Create white canvas
150
- view_square = Image.new('RGB', (target_size, target_size), (255, 255, 255))
151
- # Paste view centered
152
- paste_x = (target_size - view_w) // 2
153
- paste_y = (target_size - view_h) // 2
154
- view_square.paste(view, (paste_x, paste_y))
155
-
156
- original_views.append(view_square.copy()) # Save square original
157
- if view_square.size != (self.image_size, self.image_size):
158
- view_square = view_square.resize((self.image_size, self.image_size))
159
- selected_views.append(view_square)
160
-
161
- elif is_grid:
162
- # Grid layout detected - could be 3x2 or 2x3
163
- # Determine which based on dimensions
164
- if w > h:
165
- # Landscape: 3x2 (3 columns, 2 rows)
166
- print(f"Detected 3x2 grid layout. Aspect Ratio: {aspect_ratio}. Skipping generation & cropping.")
167
- input_image = Image.fromarray(original_img)
168
- single_view_width = w // 3
169
- single_view_height = h // 2
170
-
171
- # Row 1: Top 3 views
172
- for i in range(3):
173
- left = i * single_view_width
174
- right = (i + 1) * single_view_width
175
- view = input_image.crop((left, 0, right, single_view_height))
176
-
177
- # Pad to square (add white borders)
178
- view_w, view_h = view.size
179
- target_size = max(view_w, view_h)
180
-
181
- # Create white canvas
182
- view_square = Image.new('RGB', (target_size, target_size), (255, 255, 255))
183
- # Paste view centered
184
- paste_x = (target_size - view_w) // 2
185
- paste_y = (target_size - view_h) // 2
186
- view_square.paste(view, (paste_x, paste_y))
187
-
188
- original_views.append(view_square.copy()) # Save square original
189
- if view_square.size != (self.image_size, self.image_size):
190
- view_square = view_square.resize((self.image_size, self.image_size))
191
- selected_views.append(view_square)
192
-
193
- # Row 2: Bottom 3 views
194
- for i in range(3):
195
- left = i * single_view_width
196
- right = (i + 1) * single_view_width
197
- view = input_image.crop((left, single_view_height, right, h))
198
-
199
- # Pad to square (add white borders)
200
- view_w, view_h = view.size
201
- target_size = max(view_w, view_h)
202
-
203
- # Create white canvas
204
- view_square = Image.new('RGB', (target_size, target_size), (255, 255, 255))
205
- # Paste view centered
206
- paste_x = (target_size - view_w) // 2
207
- paste_y = (target_size - view_h) // 2
208
- view_square.paste(view, (paste_x, paste_y))
209
-
210
- original_views.append(view_square.copy()) # Save square original
211
- if view_square.size != (self.image_size, self.image_size):
212
- view_square = view_square.resize((self.image_size, self.image_size))
213
- selected_views.append(view_square)
214
- else:
215
- # Portrait: 2x3 (2 columns, 3 rows)
216
- print(f"Detected 2x3 grid layout. Aspect Ratio: {aspect_ratio}. Skipping generation & cropping.")
217
- input_image = Image.fromarray(original_img)
218
- single_view_width = w // 2
219
- single_view_height = h // 3
220
-
221
- # Process all 6 views row by row
222
- for row in range(3):
223
- for col in range(2):
224
- left = col * single_view_width
225
- right = (col + 1) * single_view_width
226
- top = row * single_view_height
227
- bottom = (row + 1) * single_view_height
228
- view = input_image.crop((left, top, right, bottom))
229
-
230
- # Pad to square (add white borders)
231
- view_w, view_h = view.size
232
- target_size = max(view_w, view_h)
233
-
234
- # Create white canvas
235
- view_square = Image.new('RGB', (target_size, target_size), (255, 255, 255))
236
- # Paste view centered
237
- paste_x = (target_size - view_w) // 2
238
- paste_y = (target_size - view_h) // 2
239
- view_square.paste(view, (paste_x, paste_y))
240
-
241
- original_views.append(view_square.copy()) # Save original
242
- if view_square.size != (self.image_size, self.image_size):
243
- view_square = view_square.resize((self.image_size, self.image_size))
244
- selected_views.append(view_square)
245
-
246
-
247
- else:
248
- # Normal flow: Preprocess -> Generate
249
- input_image = preprocess_image(original_img) if auto_crop else \
250
- preprocess_image_without_cropping(original_img)
251
 
252
- # Gemini generation requires API key
253
- if not api_key:
254
- raise gr.Error("API Key is required for generating new views. Please provide a Gemini API Key or upload a pre-generated 2x3 or 6x1 grid image.")
 
 
 
 
 
 
 
255
 
256
- print("Generating multi-view images with Gemini...")
257
  if input_image.size != (self.image_size, self.image_size):
258
  input_image = input_image.resize((self.image_size, self.image_size))
259
 
260
- try:
261
- selected_views = self.gemini_generator.generate_multiview(input_image)
262
- original_views = [v.copy() for v in selected_views] # For Gemini, they're already square
263
- except Exception as e:
264
- raise gr.Error(f"Gemini generation failed: {str(e)}. Try uploading a pre-generated 2x3 grid instead.")
265
-
266
- # Save processed input (for reference)
267
- input_path = output_dir / "input.png"
268
- input_image.save(input_path)
269
-
270
- # Save multi-view composite (preserve original aspect ratios)
271
- # Use original_views instead of selected_views
272
- max_height = max(view.size[1] for view in original_views)
273
- total_width = sum(view.size[0] for view in original_views)
274
-
275
- multiview_image = Image.new("RGB", (total_width, max_height), (255, 255, 255))
276
- x_offset = 0
277
- for view in original_views:
278
- # Center vertically if view is shorter than max_height
279
- y_offset = (max_height - view.size[1]) // 2
280
- multiview_image.paste(view, (x_offset, y_offset))
281
- x_offset += view.size[0]
282
-
283
- multiview_path = output_dir / "multiview.png"
284
- multiview_image.save(multiview_path)
285
 
286
  # Prepare 3D reconstruction input
287
  view_arrays = [np.array(view) for view in selected_views]
@@ -337,46 +241,58 @@ class FaceLiftPipeline:
337
  output_path = output_dir / "output.png"
338
  Image.fromarray(comp_image).save(output_path)
339
 
340
- return str(input_path), str(multiview_path), str(output_path), str(ply_path)
 
 
 
 
 
 
 
 
 
 
341
 
342
  except Exception as e:
343
  raise gr.Error(f"Generation failed: {str(e)}")
344
 
345
 
346
- def create_demo():
347
- """Create and return the Gradio demo interface."""
348
  pipeline = FaceLiftPipeline()
349
 
 
 
 
 
 
 
 
350
  demo = gr.Interface(
351
  fn=pipeline.generate_3d_head,
352
- title="FaceLift: Single Image 3D Face Reconstruction (Gemini)",
353
  description="""
354
- Transform a single portrait into a complete 3D head model using Gemini 2.0 Flash and GS-LRM.
355
  """,
356
  inputs=[
357
- gr.Image(type="filepath", label="Input Portrait Image"),
358
- gr.Textbox(label="API Key (Gemini)", type="password", placeholder="Optional - only needed if generating new views", value=""),
359
- gr.Dropdown(choices=["Gemini"], value="Gemini", label="Generation Model", visible=False),
360
- gr.Checkbox(value=True, label="Auto Cropping"),
361
- gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale (Unused)"),
362
- gr.Number(value=4, label="Random Seed (Unused)"),
363
- gr.Slider(10, 100, 50, step=5, label="Generation Steps (Unused)"),
364
  ],
365
  outputs=[
366
  gr.Image(label="Processed Input"),
367
  gr.Image(label="Multi-view Generation"),
368
  gr.Image(label="3D Reconstruction"),
 
369
  gr.File(label="3D Model (.ply)"),
370
  ],
 
371
  allow_flagging="never",
372
  )
373
 
374
- return demo
375
-
376
-
377
- def main():
378
- """Main function for local development."""
379
- demo = create_demo()
380
  demo.queue(max_size=10)
381
  demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
382
 
 
13
  # limitations under the License.
14
 
15
  """
16
+ FaceLift: Single Image 3D Face Reconstruction
17
+ Generates 3D head models from single images using multi-view diffusion and GS-LRM.
18
  """
19
 
20
  import json
 
30
  from PIL import Image
31
  from huggingface_hub import snapshot_download
32
 
33
+ from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
34
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
35
  from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
 
36
 
37
  # HuggingFace repository configuration
38
  HF_REPO_ID = "wlyu/OpenFaceLift"
39
 
40
  def download_weights_from_hf() -> Path:
41
+ """Download model weights from HuggingFace if not already present.
42
+
43
+ Returns:
44
+ Path to the downloaded repository
45
+ """
46
  workspace_dir = Path(__file__).parent
47
 
48
  # Check if weights already exist locally
49
+ mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
50
  gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
51
+ prompt_embeds_path = workspace_dir / "mvdiffusion/data/fixed_prompt_embeds_6view/clr_embeds.pt"
52
 
53
+ if mvdiffusion_path.exists() and gslrm_path.exists() and prompt_embeds_path.exists():
54
  print("Using local model weights")
55
  return workspace_dir
56
 
57
  print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
58
+ print("This may take a few minutes on first run...")
59
 
60
  # Download to checkpoints directory
 
61
  snapshot_download(
62
  repo_id=HF_REPO_ID,
63
  local_dir=str(workspace_dir / "checkpoints"),
64
  local_dir_use_symlinks=False,
 
65
  )
66
 
67
  print("Model weights downloaded successfully!")
68
  return workspace_dir
69
 
70
  class FaceLiftPipeline:
71
+ """Pipeline for FaceLift 3D head generation from single images."""
72
 
73
  def __init__(self):
74
  # Download weights from HuggingFace if needed
 
76
 
77
  # Setup paths
78
  self.output_dir = workspace_dir / "outputs"
79
+ self.examples_dir = workspace_dir / "examples"
80
  self.output_dir.mkdir(exist_ok=True)
81
 
82
  # Parameters
83
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
84
  self.image_size = 512
85
+ self.camera_indices = [2, 1, 0, 5, 4, 3]
86
 
87
+ # Load models
88
+ print("Loading models...")
89
+ self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
90
+ str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
91
+ torch_dtype=torch.float16,
92
+ )
93
+ self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
94
+ self.mvdiffusion_pipeline.to(self.device)
95
 
 
 
96
  with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
97
  config = edict(yaml.safe_load(f))
98
 
 
108
  self.gs_lrm_model.load_state_dict(checkpoint["model"])
109
  self.gs_lrm_model.to(self.device)
110
 
111
+ self.color_prompt_embedding = torch.load(
112
+ workspace_dir / "mvdiffusion/data/fixed_prompt_embeds_6view/clr_embeds.pt",
113
+ map_location=self.device
114
+ )
115
+
116
  with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
117
  self.cameras_data = json.load(f)["frames"]
118
 
119
  print("Models loaded successfully!")
120
 
121
+ def generate_3d_head(self, image_path, mode="Single Image", auto_crop=True, guidance_scale=3.0,
122
  random_seed=4, num_steps=50):
123
+ """Generate 3D head from single image or 6-view strip."""
124
  try:
 
 
 
 
125
  # Setup output directory
126
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
127
  output_dir = self.output_dir / timestamp
128
  output_dir.mkdir(exist_ok=True)
129
 
130
+ # Load input
131
+ original_img = Image.open(image_path)
132
+ input_path = output_dir / "input.png"
133
+ original_img.save(input_path)
134
+
 
 
 
 
 
 
 
 
 
 
135
  selected_views = []
 
136
 
137
+ if mode == "6-View Strip":
138
+ print("Processing 6-View Strip...")
139
+ # Expect 6x1 strip (3072x512)
140
+ if original_img.width != self.image_size * 6 or original_img.height != self.image_size:
141
+ # Try to resize if aspect ratio is correct
142
+ if abs(original_img.width / original_img.height - 6.0) < 0.1:
143
+ original_img = original_img.resize((self.image_size * 6, self.image_size), Image.LANCZOS)
144
+ else:
145
+ raise ValueError(f"Input must be 6x1 strip (e.g. 3072x512). Got {original_img.size}")
146
+
147
+ # Split views
148
  for i in range(6):
149
+ view = original_img.crop((self.image_size * i, 0, self.image_size * (i + 1), self.image_size))
150
+ selected_views.append(view)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # Visualization of input
153
+ multiview_image = original_img
154
+ multiview_path = output_dir / "multiview.png"
155
+ multiview_image.save(multiview_path)
156
+
157
+ else:
158
+ # Single Image Mode
159
+ input_image_arr = np.array(original_img)
160
+ input_image = preprocess_image(input_image_arr) if auto_crop else \
161
+ preprocess_image_without_cropping(input_image_arr)
162
 
 
163
  if input_image.size != (self.image_size, self.image_size):
164
  input_image = input_image.resize((self.image_size, self.image_size))
165
 
166
+ # Generate multi-view images
167
+ generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
168
+ generator.manual_seed(random_seed)
169
+
170
+ result = self.mvdiffusion_pipeline(
171
+ input_image, None,
172
+ prompt_embeds=self.color_prompt_embedding,
173
+ guidance_scale=guidance_scale,
174
+ num_images_per_prompt=1,
175
+ num_inference_steps=num_steps,
176
+ generator=generator,
177
+ eta=1.0,
178
+ )
179
+
180
+ selected_views = result.images[:6]
181
+
182
+ # Save multi-view composite
183
+ multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
184
+ for i, view in enumerate(selected_views):
185
+ multiview_image.paste(view, (self.image_size * i, 0))
186
+
187
+ multiview_path = output_dir / "multiview.png"
188
+ multiview_image.save(multiview_path)
 
 
189
 
190
  # Prepare 3D reconstruction input
191
  view_arrays = [np.array(view) for view in selected_views]
 
241
  output_path = output_dir / "output.png"
242
  Image.fromarray(comp_image).save(output_path)
243
 
244
+ # Generate turntable video
245
+ turntable_frames = render_turntable(gaussians, rendering_resolution=self.image_size,
246
+ num_views=180)
247
+ turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=180)
248
+ turntable_frames = np.ascontiguousarray(turntable_frames)
249
+
250
+ turntable_path = output_dir / "turntable.mp4"
251
+ imageseq2video(turntable_frames, str(turntable_path), fps=30)
252
+
253
+ return str(input_path), str(multiview_path), str(output_path), \
254
+ str(turntable_path), str(ply_path)
255
 
256
  except Exception as e:
257
  raise gr.Error(f"Generation failed: {str(e)}")
258
 
259
 
260
+ def main():
261
+ """Run the FaceLift application."""
262
  pipeline = FaceLiftPipeline()
263
 
264
+ # Load examples (Filtered for Single Image)
265
+ examples = []
266
+ if pipeline.examples_dir.exists():
267
+ examples = [[str(f)] for f in sorted(pipeline.examples_dir.iterdir())
268
+ if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
269
+
270
+ # Create interface
271
  demo = gr.Interface(
272
  fn=pipeline.generate_3d_head,
273
+ title="FaceLift: Single Image 3D Face Reconstruction",
274
  description="""
275
+ Transform a single portrait image OR a 6-view strip into a complete 3D head model.
276
  """,
277
  inputs=[
278
+ gr.Image(type="filepath", label="Input Image (Portrait or 6x1 Strip)"),
279
+ gr.Radio(["Single Image", "6-View Strip"], value="Single Image", label="Input Mode"),
280
+ gr.Checkbox(value=True, label="Auto Cropping (Single Image Only)"),
281
+ gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale"),
282
+ gr.Number(value=4, label="Random Seed"),
283
+ gr.Slider(10, 100, 50, step=5, label="Generation Steps"),
 
284
  ],
285
  outputs=[
286
  gr.Image(label="Processed Input"),
287
  gr.Image(label="Multi-view Generation"),
288
  gr.Image(label="3D Reconstruction"),
289
+ gr.Video(label="Turntable Animation"),
290
  gr.File(label="3D Model (.ply)"),
291
  ],
292
+ examples=examples,
293
  allow_flagging="never",
294
  )
295
 
 
 
 
 
 
 
296
  demo.queue(max_size=10)
297
  demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
298