import os
import time

import gradio as gr
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from huggingface_hub import login, snapshot_download
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from EF_Net import EF_Net
from Sci_Fi_inbetweening_pipeline import CogVideoXEFNetInbetweeningPipeline

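
# Optional Hugging Face authentication so snapshot_download can access the model repo.
# The app still starts without a token, but gated or private weights would fail to download.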
try:
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)
        print("Successfully authenticated with Hugging Face")
    else:
        print("Warning: HF_TOKEN not found")
except Exception as e:
    print(f"Warning: Could not authenticate with HF: {e}")

device = "cuda" if torch.cuda.is_available() else "cpu"

def load_pipeline(dtype_str="bfloat16"):
    """Load the Sci-Fi pipeline at startup."""
    print("Loading Sci-Fi pipeline...")

    dtype = torch.float16 if dtype_str == "float16" else torch.bfloat16

    print("Downloading model repository from Hugging Face...")
    repo_path = snapshot_download(
        repo_id="LiuhanChen/Sci-Fi",
        local_dir="./Sci-Fi-models",
        token=os.environ.get("HF_TOKEN"),
        ignore_patterns=["*.md", "*.txt", ".gitattributes"],
    )
    print(f"Models downloaded to: {repo_path}")

    model_base_path = repo_path
    cogvideo_path = os.path.join(model_base_path, "CogVideoX-5b-I2V")
    ef_net_path = os.path.join(model_base_path, "EF_Net", "EF_Net.pt")

    print(f"CogVideo path: {cogvideo_path}")
    print(f"EF-Net path: {ef_net_path}")

    if not os.path.exists(ef_net_path):
        ef_net_dir = os.path.join(model_base_path, "EF_Net")
        if os.path.exists(ef_net_dir):
            print(f"Files in EF_Net directory: {os.listdir(ef_net_dir)}")
        raise FileNotFoundError(f"EF-Net weights not found at {ef_net_path}")

    print("Loading tokenizer and text encoder...")
    tokenizer = T5Tokenizer.from_pretrained(os.path.join(cogvideo_path, "tokenizer"))
    text_encoder = T5EncoderModel.from_pretrained(
        os.path.join(cogvideo_path, "text_encoder")
    )

    print("Loading transformer...")
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        os.path.join(cogvideo_path, "transformer")
    )

    print("Loading VAE...")
    vae = AutoencoderKLCogVideoX.from_pretrained(os.path.join(cogvideo_path, "vae"))

    print("Loading scheduler...")
    scheduler = CogVideoXDDIMScheduler.from_pretrained(
        os.path.join(cogvideo_path, "scheduler")
    )

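    # EF-Net, the auxiliary conditioning network used by the inbetweening pipeline,
    # is loaded separately from the base CogVideoX checkpoint.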
print(f"Loading EF-Net from {ef_net_path}...") |
|
|
EF_Net_model = ( |
|
|
EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48) |
|
|
.requires_grad_(False) |
|
|
.eval() |
|
|
) |
|
|
|
|
|
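    # The checkpoint stores its weights under "state_dict"; strict=False tolerates key
    # mismatches, and the missing/unexpected counts are printed below as a sanity check.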
    ckpt = torch.load(ef_net_path, map_location="cpu", weights_only=False)
    EF_Net_state_dict = {name: params for name, params in ckpt["state_dict"].items()}
    m, u = EF_Net_model.load_state_dict(EF_Net_state_dict, strict=False)
    print(f"[EF-Net loaded] Missing: {len(m)} | Unexpected: {len(u)}")

print("Creating pipeline...") |
|
|
pipeline = CogVideoXEFNetInbetweeningPipeline( |
|
|
tokenizer=tokenizer, |
|
|
text_encoder=text_encoder, |
|
|
transformer=transformer, |
|
|
vae=vae, |
|
|
EF_Net_model=EF_Net_model, |
|
|
scheduler=scheduler, |
|
|
) |
|
|
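    # Reconfigure the scheduler with trailing timestep spacing, matching the usual
    # CogVideoX inference recipe.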
    pipeline.scheduler = CogVideoXDDIMScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing="trailing"
    )

    print(f"Moving pipeline to {device}...")
    pipeline.to(device)
    pipeline = pipeline.to(dtype=dtype)

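    # VAE slicing and tiling trade a little decode speed for a much lower peak memory use.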
    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()

    print("Pipeline loaded successfully!")
    return pipeline

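
# Load the pipeline once at startup so every Gradio request reuses the same in-memory model.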
print("Initializing Sci-Fi pipeline at startup...") |
|
|
pipe = load_pipeline() |
|
|
|
|
|
|
|
|
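
# Inference entry point wired to the Gradio button; default values mirror the UI controls.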
def generate_inbetweening(
    first_image: Image.Image,
    last_image: Image.Image,
    prompt: str,
    num_frames: int = 49,
    guidance_scale: float = 6.0,
    ef_net_weights: float = 1.0,
    ef_net_guidance_start: float = 0.0,
    ef_net_guidance_end: float = 1.0,
    seed: int = 42,
    progress=gr.Progress(),
):
    """Generate a frame-inbetweening video from the start frame, end frame, and prompt."""
    if first_image is None or last_image is None:
        return None, "Please upload both start and end frames!"

    if not prompt.strip():
        return None, "Please provide a text prompt!"

    try:
        progress(0.2, desc="Starting generation...")
        start_time = time.time()

        progress(0.4, desc="Processing frames...")
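        # EF_Net_weights scales the EF-Net conditioning; the guidance start/end values
        # appear to bound the portion of the denoising schedule where it is applied
        # (custom keyword arguments of CogVideoXEFNetInbetweeningPipeline).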
        video_frames = pipe(
            first_image=first_image,
            last_image=last_image,
            prompt=prompt,
            num_frames=num_frames,
            use_dynamic_cfg=False,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            EF_Net_weights=ef_net_weights,
            EF_Net_guidance_start=ef_net_guidance_start,
            EF_Net_guidance_end=ef_net_guidance_end,
        ).frames[0]

        progress(0.9, desc="Exporting video...")

output_path = f"output_{int(time.time())}.mp4" |
|
|
export_to_video(video_frames, output_path, fps=7) |
|
|
|
|
|
elapsed_time = time.time() - start_time |
|
|
status_msg = f"Video generated successfully in {elapsed_time:.2f}s" |
|
|
|
|
|
progress(1.0, desc="Done!") |
|
|
return output_path, status_msg |
|
|
|
|
|
except Exception as e: |
|
|
return None, f"Error: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
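
# Gradio UI: a "Generate" tab for running the model and an "Examples" tab with bundled frame pairs.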
with gr.Blocks(title="Sci-Fi: Frame Inbetweening") as demo:
    gr.Markdown(
        """
# Sci-Fi: Symmetric Constraint for Frame Inbetweening

Upload start and end frames to generate a smooth inbetweening video.

**The model is pre-loaded and ready to use!**
"""
    )

with gr.Tab("Generate"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
first_image = gr.Image(label="Start Frame", type="pil") |
|
|
last_image = gr.Image(label="End Frame", type="pil") |
|
|
|
|
|
with gr.Column(): |
|
|
prompt = gr.Textbox( |
|
|
label="Prompt", |
|
|
placeholder="Describe the motion or content...", |
|
|
lines=3, |
|
|
) |
|
|
|
|
|
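                # Frame counts step by 12 (13, 25, 37, 49), presumably to keep num_frames
                # compatible with the temporal compression of the CogVideoX video latents.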
                with gr.Accordion("Advanced Settings", open=False):
                    num_frames = gr.Slider(
                        minimum=13,
                        maximum=49,
                        value=49,
                        step=12,
                        label="Number of Frames",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=6.0,
                        step=0.5,
                        label="Guidance Scale",
                    )
                    ef_net_weights = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Weights",
                    )
                    ef_net_guidance_start = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.1,
                        label="EF-Net Guidance Start",
                    )
                    ef_net_guidance_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        label="EF-Net Guidance End",
                    )
                    seed = gr.Number(label="Seed", value=42, precision=0)

        generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        with gr.Row():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Textbox(label="Status", lines=2)

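        # Wire the button to the generation function; the input order must match the
        # generate_inbetweening signature.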
        generate_btn.click(
            fn=generate_inbetweening,
            inputs=[
                first_image,
                last_image,
                prompt,
                num_frames,
                guidance_scale,
                ef_net_weights,
                ef_net_guidance_start,
                ef_net_guidance_end,
                seed,
            ],
            outputs=[output_video, status_text],
        )

with gr.Tab("Examples"): |
|
|
gr.Markdown( |
|
|
""" |
|
|
## Example Inputs |
|
|
|
|
|
Try these example frame pairs from the `example_input_pairs/` folder. |
|
|
""" |
|
|
) |
|
|
|
|
|
        gr.Examples(
            examples=[
                [
                    "example_input_pairs/input_pair1/start.jpg",
                    "example_input_pairs/input_pair1/end.jpg",
                    "A smooth transition between frames",
                ],
                [
                    "example_input_pairs/input_pair2/start.jpg",
                    "example_input_pairs/input_pair2/end.jpg",
                    "Natural motion interpolation",
                ],
            ],
            inputs=[first_image, last_image, prompt],
        )


if __name__ == "__main__":
    print("App ready - pipeline is loaded and ready for inference!")
    demo.launch()