AhmadMustafa Claude committed
Commit c2da0b5 · 1 Parent(s): e8fcd6a

Initial commit with Gradio inference setup

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (3)
  1. .gitattributes +1 -0
  2. .vscode/settings.json +5 -0
  3. gradio_inference.py +289 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
  *.jpg filter=lfs diff=lfs merge=lfs -text
  *.jpeg filter=lfs diff=lfs merge=lfs -text
  *.png filter=lfs diff=lfs merge=lfs -text
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "python-envs.defaultEnvManager": "ms-python.python:conda",
+     "python-envs.defaultPackageManager": "ms-python.python:conda",
+     "python-envs.pythonProjects": []
+ }
gradio_inference.py ADDED
@@ -0,0 +1,289 @@
+ import time
+
+ import gradio as gr
+ import spaces
+ import torch
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
+ from diffusers.utils import export_to_video
+ from PIL import Image
+ from transformers import T5EncoderModel, T5Tokenizer
+
+ from cogvideo_transformer import CustomCogVideoXTransformer3DModel
+ from EF_Net import EF_Net
+ from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline
+
+ # Global variables for the pipeline
+ pipe = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ @spaces.GPU
+ def load_pipeline(
+     pretrained_model_path="THUDM/CogVideoX-5b",
+     ef_net_path="weights/EF_Net.pth",
+     dtype_str="bfloat16",
+ ):
+     """Load the Sci-Fi pipeline"""
+     global pipe
+
+     dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16
+
+     # Load models
+     tokenizer = T5Tokenizer.from_pretrained(
+         pretrained_model_path, subfolder="tokenizer"
+     )
+     text_encoder = T5EncoderModel.from_pretrained(
+         pretrained_model_path, subfolder="text_encoder"
+     )
+     transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
+         pretrained_model_path, subfolder="transformer"
+     )
+     vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae")
+     scheduler = CogVideoXDDIMScheduler.from_pretrained(
+         pretrained_model_path, subfolder="scheduler"
+     )
+
+     # Load EF-Net
+     EF_Net_model = (
+         EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
+         .requires_grad_(False)
+         .eval()
+     )
+
+     ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
+     EF_Net_state_dict = {name: params for name, params in ckpt["state_dict"].items()}
+     m, u = EF_Net_model.load_state_dict(EF_Net_state_dict, strict=False)
+     print(f"[EF-Net loaded] Missing: {len(m)} | Unexpected: {len(u)}")
+
+     # Create pipeline
+     pipe = CogVideoXEFNetInbetweeningPipeline(
+         tokenizer=tokenizer,
+         text_encoder=text_encoder,
+         transformer=transformer,
+         vae=vae,
+         EF_Net_model=EF_Net_model,
+         scheduler=scheduler,
+     )
+     pipe.scheduler = CogVideoXDDIMScheduler.from_config(
+         pipe.scheduler.config, timestep_spacing="trailing"
+     )
+
+     pipe.to(device)
+     pipe = pipe.to(dtype=dtype)
+
+     pipe.vae.enable_slicing()
+     pipe.vae.enable_tiling()
+
+     return "Pipeline loaded successfully!"
+
+
+ @spaces.GPU
+ def generate_inbetweening(
+     first_image: Image.Image,
+     last_image: Image.Image,
+     prompt: str,
+     num_frames: int = 49,
+     guidance_scale: float = 6.0,
+     ef_net_weights: float = 1.0,
+     ef_net_guidance_start: float = 0.0,
+     ef_net_guidance_end: float = 1.0,
+     seed: int = 42,
+     progress=gr.Progress(),
+ ):
+     """Generate frame inbetweening video"""
+     global pipe
+
+     if pipe is None:
+         return None, "Please load the pipeline first!"
+
+     if first_image is None or last_image is None:
+         return None, "Please upload both start and end frames!"
+
+     if not prompt.strip():
+         return None, "Please provide a text prompt!"
+
+     try:
+         progress(0, desc="Starting generation...")
+         start_time = time.time()
+
+         # Generate video
+         progress(0.2, desc="Processing frames...")
+         video_frames = pipe(
+             first_image=first_image,
+             last_image=last_image,
+             prompt=prompt,
+             num_frames=num_frames,
+             use_dynamic_cfg=False,
+             guidance_scale=guidance_scale,
+             generator=torch.Generator(device=device).manual_seed(seed),
+             EF_Net_weights=ef_net_weights,
+             EF_Net_guidance_start=ef_net_guidance_start,
+             EF_Net_guidance_end=ef_net_guidance_end,
+         ).frames[0]
+
+         progress(0.9, desc="Exporting video...")
+
+         # Export video
+         output_path = f"output_{int(time.time())}.mp4"
+         export_to_video(video_frames, output_path, fps=7)
+
+         elapsed_time = time.time() - start_time
+         status_msg = f"Video generated successfully in {elapsed_time:.2f}s"
+
+         progress(1.0, desc="Done!")
+         return output_path, status_msg
+
+     except Exception as e:
+         return None, f"Error: {str(e)}"
+
+
+ # Create Gradio interface
+ with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
+     gr.Markdown(
+         """
+         # Sci-Fi: Symmetric Constraint for Frame Inbetweening
+
+         Upload start and end frames to generate a smooth inbetweening video.
+
+         **Note:** Make sure to load the pipeline first before generating videos.
+         """
+     )
+
+     with gr.Tab("Generate"):
+         with gr.Row():
+             with gr.Column():
+                 first_image = gr.Image(label="Start Frame", type="pil")
+                 last_image = gr.Image(label="End Frame", type="pil")
+
+             with gr.Column():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Describe the motion or content...",
+                     lines=3,
+                 )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     num_frames = gr.Slider(
+                         minimum=13,
+                         maximum=49,
+                         value=49,
+                         step=12,
+                         label="Number of Frames",
+                     )
+                     guidance_scale = gr.Slider(
+                         minimum=1.0,
+                         maximum=15.0,
+                         value=6.0,
+                         step=0.5,
+                         label="Guidance Scale",
+                     )
+                     ef_net_weights = gr.Slider(
+                         minimum=0.0,
+                         maximum=2.0,
+                         value=1.0,
+                         step=0.1,
+                         label="EF-Net Weights",
+                     )
+                     ef_net_guidance_start = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.0,
+                         step=0.1,
+                         label="EF-Net Guidance Start",
+                     )
+                     ef_net_guidance_end = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=1.0,
+                         step=0.1,
+                         label="EF-Net Guidance End",
+                     )
+                     seed = gr.Number(label="Seed", value=42, precision=0)
+
+         generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
+
+         with gr.Row():
+             output_video = gr.Video(label="Generated Video")
+             status_text = gr.Textbox(label="Status", lines=2)
+
+         generate_btn.click(
+             fn=generate_inbetweening,
+             inputs=[
+                 first_image,
+                 last_image,
+                 prompt,
+                 num_frames,
+                 guidance_scale,
+                 ef_net_weights,
+                 ef_net_guidance_start,
+                 ef_net_guidance_end,
+                 seed,
+             ],
+             outputs=[output_video, status_text],
+         )
+
+     with gr.Tab("Setup"):
+         gr.Markdown(
+             """
+             ## Load Pipeline
+
+             Configure and load the model before generating videos.
+
+             **Default paths:**
+             - Model: `THUDM/CogVideoX-5b` (or your downloaded path)
+             - EF-Net: `weights/EF_Net.pth`
+             """
+         )
+
+         with gr.Row():
+             model_path = gr.Textbox(
+                 label="Pretrained Model Path",
+                 value="THUDM/CogVideoX-5b",
+                 placeholder="Path to CogVideoX model",
+             )
+             ef_net_path = gr.Textbox(
+                 label="EF-Net Checkpoint Path",
+                 value="weights/EF_Net.pth",
+                 placeholder="Path to EF-Net weights",
+             )
+
+         dtype_choice = gr.Radio(
+             choices=["bfloat16", "float16"], value="bfloat16", label="Data Type"
+         )
+
+         load_btn = gr.Button("Load Pipeline", variant="primary")
+         load_status = gr.Textbox(label="Load Status", interactive=False)
+
+         load_btn.click(
+             fn=load_pipeline,
+             inputs=[model_path, ef_net_path, dtype_choice],
+             outputs=load_status,
+         )
+
+     with gr.Tab("Examples"):
+         gr.Markdown(
+             """
+             ## Example Inputs
+
+             Try these example frame pairs from the `example_input_pairs/` folder.
+             """
+         )
+
+         gr.Examples(
+             examples=[
+                 [
+                     "example_input_pairs/input_pair1/start.jpg",
+                     "example_input_pairs/input_pair1/end.jpg",
+                     "A smooth transition between frames",
+                 ],
+                 [
+                     "example_input_pairs/input_pair2/start.jpg",
+                     "example_input_pairs/input_pair2/end.jpg",
+                     "Natural motion interpolation",
+                 ],
+             ],
+             inputs=[first_image, last_image, prompt],
+         )
+
+ if __name__ == "__main__":
+     demo.launch()
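
For a quick sanity check outside the UI, the same pipeline can be driven directly from Python. This is a minimal sketch, not part of the commit: it assumes you run it from the repo root on a CUDA machine, that `weights/EF_Net.pth` and the example images referenced in the Examples tab exist locally, and that the `@spaces.GPU` decorator acts as a pass-through outside a Hugging Face Space.

```python
# Hypothetical local smoke test for gradio_inference.py (not included in
# this commit). It mirrors the pipeline call inside generate_inbetweening().
import torch
from PIL import Image
from diffusers.utils import export_to_video

import gradio_inference as app  # importing builds the UI but does not launch it

# Populate the module-level `pipe` global (downloads THUDM/CogVideoX-5b on
# first use and loads the local EF-Net checkpoint).
app.load_pipeline(
    pretrained_model_path="THUDM/CogVideoX-5b",
    ef_net_path="weights/EF_Net.pth",
    dtype_str="bfloat16",
)

# Same keyword arguments the Gradio handler passes to the pipeline.
frames = app.pipe(
    first_image=Image.open("example_input_pairs/input_pair1/start.jpg"),
    last_image=Image.open("example_input_pairs/input_pair1/end.jpg"),
    prompt="A smooth transition between frames",
    num_frames=49,
    use_dynamic_cfg=False,
    guidance_scale=6.0,
    generator=torch.Generator(device=app.device).manual_seed(42),
    EF_Net_weights=1.0,
    EF_Net_guidance_start=0.0,
    EF_Net_guidance_end=1.0,
).frames[0]

export_to_video(frames, "smoke_test.mp4", fps=7)
```

Launching the app itself is just `python gradio_inference.py`, which reaches the `demo.launch()` call under the `__main__` guard above.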