Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from diffusers import LTXVideoTransformer3DModel, LTXVideoPipeline | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| import spaces | |
| import numpy as np | |
| import tempfile | |
| import os | |
| import time | |
| import logging | |
| from PIL import Image | |
| import cv2 | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import FileResponse | |
| import uvicorn | |
| import threading | |
| import json | |
# Logging: emit INFO and above, and log through a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared model state.  `pipe` stays None until load_model() assigns a pipeline;
# `device` is "cuda" when torch reports an available GPU and "cpu" otherwise.
pipe = None
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model():
    """Load the LTX-Video pipeline into the module-level ``pipe``.

    Loads ``Lightricks/LTX-Video-0.9.7-dev`` in bfloat16 from safetensors
    weights, moves it to ``device`` and enables VAE tiling and slicing.

    Returns:
        bool: True when the pipeline is ready.  False when anything failed;
        the error is logged and ``pipe`` is left as None so callers that
        check ``pipe is None`` never see a half-initialised pipeline.
    """
    global pipe
    try:
        logger.info("Loading LTX-Video model...")
        # NOTE(review): the diffusers docs name the LTX text-to-video class
        # `LTXPipeline`; confirm `LTXVideoPipeline` resolves in the installed
        # diffusers version.
        loaded = LTXVideoPipeline.from_pretrained(
            "Lightricks/LTX-Video-0.9.7-dev",
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        loaded = loaded.to(device)

        # VAE memory optimizations.
        loaded.vae.enable_tiling()
        loaded.vae.enable_slicing()

        # Memory-efficient attention is optional.  Look `unet` up defensively:
        # a pipeline without that attribute (LTX pipelines appear to be
        # transformer-based) must not fail the whole load with AttributeError.
        unet = getattr(loaded, "unet", None)
        if unet is not None and hasattr(
            unet, "enable_xformers_memory_efficient_attention"
        ):
            unet.enable_xformers_memory_efficient_attention()

        # Publish only a fully configured pipeline.
        pipe = loaded
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        pipe = None
        logger.error(f"Error loading model: {e}")
        return False
def validate_inputs(prompt, duration, image=None):
    """Validate generation parameters and collect human-readable errors.

    Args:
        prompt: Text prompt; must be non-blank and at most 500 characters.
            ``None`` is treated as an empty prompt.
        duration: Clip length in seconds; must lie in the range 3 to 5.
            ``None`` and NaN are rejected.
        image: Optional input image, either a file path (``str``) or an
            object exposing a ``size`` attribute of ``(width, height)``.
            Neither dimension may exceed 1024.

    Returns:
        list[str]: One message per failed check; empty when all pass.
    """
    errors = []

    # Normalise None to "" so the length check below cannot raise TypeError.
    text = prompt or ""
    if not text.strip():
        errors.append("Prompt is required")
    if len(text) > 500:
        errors.append("Prompt must be less than 500 characters")

    # The chained comparison is False for NaN, so NaN is rejected instead of
    # slipping past two separate `<` / `>` tests.
    if duration is None or not 3 <= duration <= 5:
        errors.append("Duration must be between 3 and 5 seconds")

    if image is not None:
        try:
            if isinstance(image, str):
                # Close the file handle as soon as the size has been read.
                with Image.open(image) as img:
                    width, height = img.size
            else:
                width, height = image.size
            if width > 1024 or height > 1024:
                errors.append("Image dimensions must be less than 1024x1024")
        except Exception as e:
            # An unreadable image is a validation failure, not a crash.
            errors.append(f"Invalid image: {str(e)}")

    return errors
def frames_to_video(frames, output_path, fps=24):
    """Encode a sequence of RGB frames into a video file with OpenCV.

    Args:
        frames: Indexable, sized collection of frames.  Each frame must expose
            ``shape`` (height and width are taken from the first frame) and be
            convertible with ``cv2.COLOR_RGB2BGR``.
        output_path: Destination file path; written with the ``mp4v`` codec.
        fps: Frame rate of the output video.

    Returns:
        bool: True when every frame was written.  False on any failure
        (no frames, writer could not be opened, conversion error); the cause
        is logged.
    """
    writer = None
    try:
        if len(frames) == 0:
            raise ValueError("no frames to encode")

        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # VideoWriter does not raise when it cannot open the target; without
        # this check a failed open would silently drop frames and report True.
        if not writer.isOpened():
            raise RuntimeError(f"could not open video writer for {output_path}")

        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        return True
    except Exception as e:
        logger.error(f"Error creating video: {e}")
        return False
    finally:
        # Release the writer on every path so the file is finalised and the
        # handle is not leaked when a frame fails mid-way.
        if writer is not None:
            writer.release()
def _frames_to_uint8(frames):
    """Return pipeline output frames as a single uint8 NumPy array."""
    if torch.is_tensor(frames):
        # bfloat16 has no NumPy dtype, so upcast before leaving torch.
        frames = frames.float().cpu().numpy()
    elif not isinstance(frames, np.ndarray):
        # e.g. a list of PIL images (the pipeline's default output type).
        frames = np.stack([np.asarray(frame) for frame in frames])
    if frames.dtype != np.uint8:
        # Assumes float frames are scaled to [0, 1] — TODO confirm.  Clipping
        # prevents out-of-range values from wrapping around in uint8.
        frames = (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)
    return frames


def generate_video_core(prompt, negative_prompt="", duration=4, image=None):
    """Run the loaded pipeline and write the result to a temporary MP4 file.

    Generation uses a 768x512 output, 30 inference steps, guidance scale 7.5
    and a fixed seed (42), at 24 frames per second.

    NOTE(review): the previous docstring mentioned a ZeroGPU decorator, but no
    decorator is applied to this function — confirm whether ``@spaces.GPU``
    should wrap it.

    Args:
        prompt: Text description of the video.
        negative_prompt: Text describing what the video should avoid.
        duration: Clip length in seconds.
        image: Optional conditioning image: a PIL image or a file path.

    Returns:
        tuple[str, str]: Path of the generated MP4 and a status message.

    Raises:
        Exception: ``"Generation failed: ..."`` chained to the underlying
            error, for any failure during generation or encoding.
    """
    start_time = time.time()
    try:
        if pipe is None:
            raise RuntimeError("Model not loaded")

        # 24 FPS, rounded up to the next 8k+1 frame count: the LTX-Video model
        # card states frame counts should be a multiple of 8 plus 1.
        requested_frames = max(int(duration * 24), 1)
        num_frames = ((requested_frames + 6) // 8) * 8 + 1

        generation_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_frames": num_frames,
            "height": 512,
            "width": 768,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "generator": torch.Generator(device=device).manual_seed(42),
        }

        if image is not None:
            if isinstance(image, str):
                # Decode fully inside the context manager so the file handle
                # is closed before generation starts.
                with Image.open(image) as source:
                    image = source.convert("RGB")
            else:
                image = image.convert("RGB")
            # Resize image to match output dimensions.
            image = image.resize((768, 512), Image.Resampling.LANCZOS)
            # NOTE(review): a text-to-video pipeline may not accept an `image`
            # argument; image conditioning may need an image-to-video
            # pipeline — confirm against the diffusers LTX documentation.
            generation_kwargs["image"] = image

        logger.info(f"Starting generation with {num_frames} frames...")
        with torch.inference_mode():
            result = pipe(**generation_kwargs)

        # First (and only) video in the batch.
        frames = _frames_to_uint8(result.frames[0])

        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "generated_video.mp4")
        if not frames_to_video(frames, video_path, fps=24):
            raise Exception("Failed to create video file")

        generation_time = time.time() - start_time
        logger.info(f"Video generated successfully in {generation_time:.2f} seconds")
        return video_path, f"Generated in {generation_time:.2f}s"
    except Exception as e:
        logger.error(f"Error generating video: {e}")
        # Keep the plain Exception type callers already catch, but preserve
        # the original traceback as the cause.
        raise Exception(f"Generation failed: {str(e)}") from e
def generate_video_gradio(prompt, negative_prompt, duration, image):
    """Adapter between the Gradio click handler and the generation core.

    Returns a ``(video_path, status_text)`` pair.  ``video_path`` is None
    whenever validation fails, the model is not loaded, or generation raises;
    in those cases ``status_text`` carries the explanation.
    """
    try:
        problems = validate_inputs(prompt, duration, image)
        if problems:
            joined = "; ".join(problems)
            return None, f"Validation errors: {joined}"

        if pipe is None:
            return None, "Model not loaded. Please wait for initialization."

        path, message = generate_video_core(prompt, negative_prompt, duration, image)
    except Exception as exc:
        # Surface the failure in the status box instead of raising into Gradio.
        logger.error(f"Gradio generation error: {exc}")
        return None, f"Error: {exc!s}"
    return path, message
# Gradio UI construction
def create_gradio_interface():
    """Build and return the Gradio Blocks app (not yet launched).

    Layout: a row of two equal columns — inputs (image, prompt, negative
    prompt, duration, generate button) on the left, and the generated video
    plus a read-only status box on the right — followed by example inputs.
    """
    sample_rows = [
        ["A cat playing with a ball of yarn", "", 4, None],
        ["Ocean waves crashing on a beach at sunset", "", 3, None],
        ["A person walking through a forest", "blurry, low quality", 5, None],
    ]

    with gr.Blocks(title="LTX-Video Generator", theme=gr.themes.Soft()) as ui:
        gr.Markdown("# 🎬 LTX-Video Generator")
        gr.Markdown("Generate 3-5 second videos using the LTX-Video model from Lightricks")

        with gr.Row():
            # Left column: everything the user supplies.
            with gr.Column(scale=1):
                image_input = gr.File(
                    label="Input Image (Optional)",
                    file_types=[".png", ".jpg", ".jpeg"],
                    type="filepath",
                )
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    max_lines=5,
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    max_lines=3,
                )
                duration_slider = gr.Slider(
                    minimum=3,
                    maximum=5,
                    value=4,
                    step=0.5,
                    label="Duration (seconds)",
                )
                generate_btn = gr.Button("🎬 Generate Video", variant="primary")
                gr.Markdown("**Estimated time:** 4-6 seconds")

            # Right column: results.
            with gr.Column(scale=1):
                video_output = gr.Video(label="Generated Video")
                status_output = gr.Textbox(label="Status", interactive=False)

        # The button and the examples feed the same components, in the order
        # generate_video_gradio expects its arguments.
        form_inputs = [prompt_input, negative_prompt_input, duration_slider, image_input]

        generate_btn.click(
            fn=generate_video_gradio,
            inputs=form_inputs,
            outputs=[video_output, status_output],
        )

        gr.Examples(examples=sample_rows, inputs=form_inputs)

    return ui
# FastAPI setup
app = FastAPI(title="LTX-Video API", description="Generate videos using LTX-Video model")


# Register the handler; without a route decorator the app serves no
# /generate_video endpoint at all.
@app.post("/generate_video")
async def api_generate_video(
    prompt: str = Form(..., description="Text prompt for video generation"),
    negative_prompt: str = Form("", description="Negative prompt (optional)"),
    duration: float = Form(4.0, description="Duration in seconds (3-5)"),
    image: UploadFile = File(None, description="Input image (optional)"),
):
    """Generate a video from multipart form data and return it as an MP4.

    Args:
        prompt: Text prompt for the video.
        negative_prompt: Optional text describing what to avoid.
        duration: Clip length in seconds (validated to 3-5).
        image: Optional uploaded input image.

    Returns:
        FileResponse: The generated video with media type ``video/mp4``.

    Raises:
        HTTPException: 400 with ``{"errors": [...]}`` when validation fails,
            503 when the model is not loaded, 500 for any other failure.
    """
    try:
        # The upload only has to live until generation finishes, so keep it in
        # a directory that is removed automatically on every exit path.
        with tempfile.TemporaryDirectory() as upload_dir:
            image_path = None
            if image:
                # The filename is client-controlled: keep only its final path
                # component so it cannot escape the temporary directory.
                safe_name = os.path.basename(image.filename or "") or "upload"
                image_path = os.path.join(upload_dir, safe_name)
                content = await image.read()
                with open(image_path, "wb") as f:
                    f.write(content)

            errors = validate_inputs(prompt, duration, image_path)
            if errors:
                raise HTTPException(status_code=400, detail={"errors": errors})

            if pipe is None:
                raise HTTPException(status_code=503, detail="Model not loaded")

            # NOTE(review): this call is synchronous and blocks the API event
            # loop while a video is generated — confirm that is acceptable.
            video_path, _status = generate_video_core(
                prompt, negative_prompt, duration, image_path
            )

        return FileResponse(
            video_path,
            media_type="video/mp4",
            filename=f"generated_video_{int(time.time())}.mp4",
        )
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"API generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Register the handler; without a route decorator the app never serves "/".
@app.get("/")
async def root():
    """Return a short JSON description of the API.

    Returns:
        dict: A ``message``, a map of ``endpoints`` to descriptions, and a
        ``curl_example`` showing a multipart request to ``/generate_video``.
    """
    # The example targets port 7861, where run_api() serves this app; port
    # 7860 is the Gradio UI.
    curl_example = (
        "\n"
        'curl -X POST "http://localhost:7861/generate_video" \\\n'
        '-F "prompt=A cat playing with a ball" \\\n'
        '-F "duration=4" \\\n'
        '-F "negative_prompt=blurry" \\\n'
        '-F "image=@your_image.jpg" \\\n'
        "--output generated_video.mp4\n"
    )
    return {
        "message": "LTX-Video API",
        "endpoints": {
            "/generate_video": "POST - Generate video",
            "/docs": "GET - API documentation",
        },
        "curl_example": curl_example,
    }
def run_api():
    """Serve the FastAPI ``app`` with uvicorn on all interfaces, port 7861.

    Blocks until the server stops, so it is meant to run in its own thread.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": 7861,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
def main():
    """Start the application: load the model, then the API and the Gradio UI.

    Returns early (after logging an error) when the model cannot be loaded.
    Otherwise the FastAPI server runs in a daemon thread and the Gradio
    interface is launched on port 7860.
    """
    logger.info("Initializing LTX-Video Generator...")
    if not load_model():
        logger.error("Failed to load model. Exiting.")
        return

    demo = create_gradio_interface()

    # Daemon thread: the API server must not keep the process alive once the
    # Gradio launch below returns.
    threading.Thread(target=run_api, daemon=True).start()
    logger.info("API server started on http://localhost:7861")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": False,
    }
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()