dream2589632147 committed on
Commit
10eab12
·
verified ·
1 Parent(s): 88079b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -71
app.py CHANGED
@@ -35,9 +35,9 @@ MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
35
  # =========================================================
36
  # LOAD PIPELINE
37
  # =========================================================
38
- print("Loading pipeline...")
39
 
40
- # 1. تحميل المكونات بدون تحديد device_map="cuda" لتجنب تعارض ZeroGPU
41
  transformer = WanTransformer3DModel.from_pretrained(
42
  MODEL_ID,
43
  subfolder="transformer",
@@ -52,51 +52,62 @@ transformer_2 = WanTransformer3DModel.from_pretrained(
52
  token=HF_TOKEN
53
  )
54
 
55
- # 2. تجميع البايبلاين
56
  pipe = WanImageToVideoPipeline.from_pretrained(
57
  MODEL_ID,
58
  transformer=transformer,
59
  transformer_2=transformer_2,
60
  torch_dtype=torch.bfloat16,
 
61
  )
62
 
63
- # 3. نقل الموديل للـ CUDA مرة واحدة هنا
64
- print("Moving pipeline to CUDA...")
65
  pipe = pipe.to("cuda")
66
 
67
  # =========================================================
68
  # LOAD LORA ADAPTERS
69
  # =========================================================
70
  print("Loading LoRA adapters...")
71
- pipe.load_lora_weights(
72
- "Kijai/WanVideo_comfy",
73
- weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
74
- adapter_name="lightx2v"
75
- )
76
- pipe.load_lora_weights(
77
- "Kijai/WanVideo_comfy",
78
- weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
79
- adapter_name="lightx2v_2",
80
- load_into_transformer_2=True
81
- )
 
82
 
83
- pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
84
- pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
85
- pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
86
- pipe.unload_lora_weights()
 
 
 
87
 
88
  # =========================================================
89
  # QUANTIZATION & AOT OPTIMIZATION
90
  # =========================================================
91
  print("Applying quantization...")
92
- # نقلنا التكميم بعد النقل للـ GPU لضمان التوافق
93
- quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
94
- quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
95
- quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
96
 
97
- print("Loading AOTI blocks...")
98
- aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
99
- aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
 
 
 
 
 
 
100
 
101
  # =========================================================
102
  # DEFAULT PROMPTS
@@ -150,26 +161,11 @@ def resize_image(image: Image.Image) -> Image.Image:
150
  def get_num_frames(duration_seconds: float):
151
  return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
152
 
153
- def get_duration(
154
- input_image, prompt, steps, negative_prompt,
155
- duration_seconds, guidance_scale, guidance_scale_2,
156
- seed, randomize_seed, progress,
157
- ):
158
- if input_image is None:
159
- return 120
160
- BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
161
- BASE_STEP_DURATION = 15
162
-
163
- width, height = resize_image(input_image).size
164
- frames = get_num_frames(duration_seconds)
165
- factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
166
- step_duration = BASE_STEP_DURATION * factor ** 1.5
167
- return 10 + int(steps) * step_duration
168
-
169
  # =========================================================
170
  # MAIN GENERATION FUNCTION
171
  # =========================================================
172
- @spaces.GPU(duration=get_duration)
 
173
  def generate_video(
174
  input_image,
175
  prompt,
@@ -182,31 +178,49 @@ def generate_video(
182
  randomize_seed=False,
183
  progress=gr.Progress(track_tqdm=True),
184
  ):
185
- if input_image is None:
186
- raise gr.Error("Please upload an input image.")
187
-
188
- num_frames = get_num_frames(duration_seconds)
189
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
190
- resized_image = resize_image(input_image)
 
 
 
 
 
191
 
192
- output_frames_list = pipe(
193
- image=resized_image,
194
- prompt=prompt,
195
- negative_prompt=negative_prompt,
196
- height=resized_image.height,
197
- width=resized_image.width,
198
- num_frames=num_frames,
199
- guidance_scale=float(guidance_scale),
200
- guidance_scale_2=float(guidance_scale_2),
201
- num_inference_steps=int(steps),
202
- generator=torch.Generator(device="cuda").manual_seed(current_seed),
203
- ).frames[0]
204
 
205
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
206
- video_path = tmpfile.name
207
-
208
- export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
209
- return video_path, current_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  # =========================================================
212
  # GRADIO UI
@@ -285,9 +299,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
285
 
286
  generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
287
 
288
- # ملاحظة: يمكنك حذف جزء Examples إذا استمرت الأخطاء أو إذا لم تكن الصورة مرفوعة
289
- # gr.Examples(...)
290
-
291
  # --- BOTTOM ADVERTISEMENT BANNER ---
292
  gr.HTML("""
293
  <div style="background: linear-gradient(90deg, #4f46e5, #9333ea); color: white; padding: 15px; border-radius: 10px; text-align: center; margin-top: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.1);">
@@ -306,4 +317,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
306
  """)
307
 
308
  if __name__ == "__main__":
309
- demo.queue().launch(mcp_server=True)
 
 
35
  # =========================================================
36
  # LOAD PIPELINE
37
  # =========================================================
38
+ print("Loading pipeline components...")
39
 
40
+ # تحميل المكونات أولاً بدون نقلها للـ GPU لتوفير الذاكرة أثناء التحميل
41
  transformer = WanTransformer3DModel.from_pretrained(
42
  MODEL_ID,
43
  subfolder="transformer",
 
52
  token=HF_TOKEN
53
  )
54
 
55
+ print("Assembling pipeline...")
56
  pipe = WanImageToVideoPipeline.from_pretrained(
57
  MODEL_ID,
58
  transformer=transformer,
59
  transformer_2=transformer_2,
60
  torch_dtype=torch.bfloat16,
61
+ token=HF_TOKEN
62
  )
63
 
64
+ # نقل الموديل إلى CUDA الآن
65
+ print("Moving to CUDA...")
66
  pipe = pipe.to("cuda")
67
 
68
  # =========================================================
69
  # LOAD LORA ADAPTERS
70
  # =========================================================
71
  print("Loading LoRA adapters...")
72
+ try:
73
+ pipe.load_lora_weights(
74
+ "Kijai/WanVideo_comfy",
75
+ weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
76
+ adapter_name="lightx2v"
77
+ )
78
+ pipe.load_lora_weights(
79
+ "Kijai/WanVideo_comfy",
80
+ weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
81
+ adapter_name="lightx2v_2",
82
+ load_into_transformer_2=True
83
+ )
84
 
85
+ pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
86
+ pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
87
+ pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
88
+ pipe.unload_lora_weights()
89
+ print("LoRA loaded and fused successfully.")
90
+ except Exception as e:
91
+ print(f"Warning: Failed to load LoRA. Continuing without it. Error: {e}")
92
 
93
  # =========================================================
94
  # QUANTIZATION & AOT OPTIMIZATION
95
  # =========================================================
96
  print("Applying quantization...")
97
+ # تنظيف الذاكرة قبل العمليات الثقيلة
98
+ torch.cuda.empty_cache()
99
+ gc.collect()
 
100
 
101
+ try:
102
+ quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
103
+ quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
104
+ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
105
+
106
+ print("Loading AOTI blocks...")
107
+ aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
108
+ aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
109
+ except Exception as e:
110
+ print(f"Warning: Quantization/AOTI failed. Running in standard mode might OOM. Error: {e}")
111
 
112
  # =========================================================
113
  # DEFAULT PROMPTS
 
161
  def get_num_frames(duration_seconds: float):
162
  return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  # =========================================================
165
  # MAIN GENERATION FUNCTION
166
  # =========================================================
167
+ # زيادة الوقت المسموح به إلى 180 ثانية لتجنب التايم أوت
168
+ @spaces.GPU(duration=180)
169
  def generate_video(
170
  input_image,
171
  prompt,
 
178
  randomize_seed=False,
179
  progress=gr.Progress(track_tqdm=True),
180
  ):
181
+ # تنظيف الذاكرة في بداية الدالة
182
+ gc.collect()
183
+ torch.cuda.empty_cache()
184
+
185
+ try:
186
+ if input_image is None:
187
+ raise gr.Error("Please upload an input image.")
188
+
189
+ num_frames = get_num_frames(duration_seconds)
190
+ current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
191
+ resized_image = resize_image(input_image)
192
 
193
+ print(f"Generating video with seed: {current_seed}, frames: {num_frames}")
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ output_frames_list = pipe(
196
+ image=resized_image,
197
+ prompt=prompt,
198
+ negative_prompt=negative_prompt,
199
+ height=resized_image.height,
200
+ width=resized_image.width,
201
+ num_frames=num_frames,
202
+ guidance_scale=float(guidance_scale),
203
+ guidance_scale_2=float(guidance_scale_2),
204
+ num_inference_steps=int(steps),
205
+ generator=torch.Generator(device="cuda").manual_seed(current_seed),
206
+ ).frames[0]
207
+
208
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
209
+ video_path = tmpfile.name
210
+
211
+ export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
212
+
213
+ # تنظيف الذاكرة بعد الانتهاء
214
+ del output_frames_list
215
+ torch.cuda.empty_cache()
216
+
217
+ return video_path, current_seed
218
+
219
+ except Exception as e:
220
+ # طباعة الخطأ الحقيقي في الكونسول
221
+ print(f"Error during generation: {e}")
222
+ # إعادة رفع الخطأ ليظهر للمستخدم
223
+ raise gr.Error(f"Generation failed: {str(e)}")
224
 
225
  # =========================================================
226
  # GRADIO UI
 
299
 
300
  generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
301
 
 
 
 
302
  # --- BOTTOM ADVERTISEMENT BANNER ---
303
  gr.HTML("""
304
  <div style="background: linear-gradient(90deg, #4f46e5, #9333ea); color: white; padding: 15px; border-radius: 10px; text-align: center; margin-top: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.1);">
 
317
  """)
318
 
319
  if __name__ == "__main__":
320
+ # تم إزالة mcp_server=True لأنه يسبب مشاكل
321
+ demo.queue().launch()