AhmadMustafa committed
Commit c4279ee · 1 Parent(s): d230b19
Files changed (2)
  1. README.md +1 -1
  2. gradio_inference.py +0 -289
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: blue
  colorTo: purple
  sdk: gradio
  sdk_version: 4.44.0
- app_file: gradio_inference.py
+ app_file: app.py
  pinned: false
  license: apache-2.0
  ---
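In a Hugging Face Space, the `app_file` field of the README front matter names the script that the Space launches at startup, so it must point at a file that exists in the repository. The new `app.py` is not included in this commit; purely as a hedged illustration, a minimal Gradio entry point of the kind it would need to expose might look like the following hypothetical sketch (its actual contents are an assumption, not part of this diff):

```python
# Hypothetical sketch of an app.py entry point; the real app.py is not shown in this commit.
import gradio as gr

with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
    gr.Markdown("# Sci-Fi: Symmetric Constraint for Frame Inbetweening")
    # UI components and event handlers for the inbetweening demo would be defined here.

if __name__ == "__main__":
    demo.launch()
```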
gradio_inference.py DELETED
@@ -1,289 +0,0 @@
- import time
-
- import gradio as gr
- import spaces
- import torch
- from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
- from diffusers.utils import export_to_video
- from PIL import Image
- from transformers import T5EncoderModel, T5Tokenizer
-
- from cogvideo_transformer import CustomCogVideoXTransformer3DModel
- from EF_Net import EF_Net
- from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline
-
- # Global variables for the pipeline
- pipe = None
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
- @spaces.GPU
- def load_pipeline(
-     pretrained_model_path="THUDM/CogVideoX-5b",
-     ef_net_path="weights/EF_Net.pth",
-     dtype_str="bfloat16",
- ):
-     """Load the Sci-Fi pipeline"""
-     global pipe
-
-     dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16
-
-     # Load models
-     tokenizer = T5Tokenizer.from_pretrained(
-         pretrained_model_path, subfolder="tokenizer"
-     )
-     text_encoder = T5EncoderModel.from_pretrained(
-         pretrained_model_path, subfolder="text_encoder"
-     )
-     transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
-         pretrained_model_path, subfolder="transformer"
-     )
-     vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae")
-     scheduler = CogVideoXDDIMScheduler.from_pretrained(
-         pretrained_model_path, subfolder="scheduler"
-     )
-
-     # Load EF-Net
-     EF_Net_model = (
-         EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
-         .requires_grad_(False)
-         .eval()
-     )
-
-     ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
-     EF_Net_state_dict = {name: params for name, params in ckpt["state_dict"].items()}
-     m, u = EF_Net_model.load_state_dict(EF_Net_state_dict, strict=False)
-     print(f"[EF-Net loaded] Missing: {len(m)} | Unexpected: {len(u)}")
-
-     # Create pipeline
-     pipe = CogVideoXEFNetInbetweeningPipeline(
-         tokenizer=tokenizer,
-         text_encoder=text_encoder,
-         transformer=transformer,
-         vae=vae,
-         EF_Net_model=EF_Net_model,
-         scheduler=scheduler,
-     )
-     pipe.scheduler = CogVideoXDDIMScheduler.from_config(
-         pipe.scheduler.config, timestep_spacing="trailing"
-     )
-
-     pipe.to(device)
-     pipe = pipe.to(dtype=dtype)
-
-     pipe.vae.enable_slicing()
-     pipe.vae.enable_tiling()
-
-     return "Pipeline loaded successfully!"
-
-
- @spaces.GPU
- def generate_inbetweening(
-     first_image: Image.Image,
-     last_image: Image.Image,
-     prompt: str,
-     num_frames: int = 49,
-     guidance_scale: float = 6.0,
-     ef_net_weights: float = 1.0,
-     ef_net_guidance_start: float = 0.0,
-     ef_net_guidance_end: float = 1.0,
-     seed: int = 42,
-     progress=gr.Progress(),
- ):
-     """Generate frame inbetweening video"""
-     global pipe
-
-     if pipe is None:
-         return None, "Please load the pipeline first!"
-
-     if first_image is None or last_image is None:
-         return None, "Please upload both start and end frames!"
-
-     if not prompt.strip():
-         return None, "Please provide a text prompt!"
-
-     try:
-         progress(0, desc="Starting generation...")
-         start_time = time.time()
-
-         # Generate video
-         progress(0.2, desc="Processing frames...")
-         video_frames = pipe(
-             first_image=first_image,
-             last_image=last_image,
-             prompt=prompt,
-             num_frames=num_frames,
-             use_dynamic_cfg=False,
-             guidance_scale=guidance_scale,
-             generator=torch.Generator(device=device).manual_seed(seed),
-             EF_Net_weights=ef_net_weights,
-             EF_Net_guidance_start=ef_net_guidance_start,
-             EF_Net_guidance_end=ef_net_guidance_end,
-         ).frames[0]
-
-         progress(0.9, desc="Exporting video...")
-
-         # Export video
-         output_path = f"output_{int(time.time())}.mp4"
-         export_to_video(video_frames, output_path, fps=7)
-
-         elapsed_time = time.time() - start_time
-         status_msg = f"Video generated successfully in {elapsed_time:.2f}s"
-
-         progress(1.0, desc="Done!")
-         return output_path, status_msg
-
-     except Exception as e:
-         return None, f"Error: {str(e)}"
-
-
- # Create Gradio interface
- with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
-     gr.Markdown(
-         """
- # Sci-Fi: Symmetric Constraint for Frame Inbetweening
-
- Upload start and end frames to generate smooth inbetweening video.
-
- **Note:** Make sure to load the pipeline first before generating videos.
- """
-     )
-
-     with gr.Tab("Generate"):
-         with gr.Row():
-             with gr.Column():
-                 first_image = gr.Image(label="Start Frame", type="pil")
-                 last_image = gr.Image(label="End Frame", type="pil")
-
-             with gr.Column():
-                 prompt = gr.Textbox(
-                     label="Prompt",
-                     placeholder="Describe the motion or content...",
-                     lines=3,
-                 )
-
-                 with gr.Accordion("Advanced Settings", open=False):
-                     num_frames = gr.Slider(
-                         minimum=13,
-                         maximum=49,
-                         value=49,
-                         step=12,
-                         label="Number of Frames",
-                     )
-                     guidance_scale = gr.Slider(
-                         minimum=1.0,
-                         maximum=15.0,
-                         value=6.0,
-                         step=0.5,
-                         label="Guidance Scale",
-                     )
-                     ef_net_weights = gr.Slider(
-                         minimum=0.0,
-                         maximum=2.0,
-                         value=1.0,
-                         step=0.1,
-                         label="EF-Net Weights",
-                     )
-                     ef_net_guidance_start = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.0,
-                         value=0.0,
-                         step=0.1,
-                         label="EF-Net Guidance Start",
-                     )
-                     ef_net_guidance_end = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.0,
-                         value=1.0,
-                         step=0.1,
-                         label="EF-Net Guidance End",
-                     )
-                     seed = gr.Number(label="Seed", value=42, precision=0)
-
-         generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
-
-         with gr.Row():
-             output_video = gr.Video(label="Generated Video")
-             status_text = gr.Textbox(label="Status", lines=2)
-
-         generate_btn.click(
-             fn=generate_inbetweening,
-             inputs=[
-                 first_image,
-                 last_image,
-                 prompt,
-                 num_frames,
-                 guidance_scale,
-                 ef_net_weights,
-                 ef_net_guidance_start,
-                 ef_net_guidance_end,
-                 seed,
-             ],
-             outputs=[output_video, status_text],
-         )
-
-     with gr.Tab("Setup"):
-         gr.Markdown(
-             """
- ## Load Pipeline
-
- Configure and load the model before generating videos.
-
- **Default paths:**
- - Model: `THUDM/CogVideoX-5b` (or your downloaded path)
- - EF-Net: `weights/EF_Net.pth`
- """
-         )
-
-         with gr.Row():
-             model_path = gr.Textbox(
-                 label="Pretrained Model Path",
-                 value="THUDM/CogVideoX-5b",
-                 placeholder="Path to CogVideoX model",
-             )
-             ef_net_path = gr.Textbox(
-                 label="EF-Net Checkpoint Path",
-                 value="weights/EF_Net.pth",
-                 placeholder="Path to EF-Net weights",
-             )
-
-         dtype_choice = gr.Radio(
-             choices=["bfloat16", "float16"], value="bfloat16", label="Data Type"
-         )
-
-         load_btn = gr.Button("Load Pipeline", variant="primary")
-         load_status = gr.Textbox(label="Load Status", interactive=False)
-
-         load_btn.click(
-             fn=load_pipeline,
-             inputs=[model_path, ef_net_path, dtype_choice],
-             outputs=load_status,
-         )
-
-     with gr.Tab("Examples"):
-         gr.Markdown(
-             """
- ## Example Inputs
-
- Try these example frame pairs from the `example_input_pairs/` folder.
- """
-         )
-
-         gr.Examples(
-             examples=[
-                 [
-                     "example_input_pairs/input_pair1/start.jpg",
-                     "example_input_pairs/input_pair1/end.jpg",
-                     "A smooth transition between frames",
-                 ],
-                 [
-                     "example_input_pairs/input_pair2/start.jpg",
-                     "example_input_pairs/input_pair2/end.jpg",
-                     "Natural motion interpolation",
-                 ],
-             ],
-             inputs=[first_image, last_image, prompt],
-         )
-
- if __name__ == "__main__":
-     demo.launch()