import spaces
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# --- 1. Model Loading and Optimization (AoT Compilation) ---

# Choose a Stable Diffusion model
# (the original runwayml/stable-diffusion-v1-5 repo was removed from the Hub;
# this community mirror hosts the same weights)
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Initialize the pipeline; disable the safety checker to skip its extra forward pass
# and use torch.float16 for efficiency on CUDA hardware
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False
)
pipe.to("cuda")
pipe.scheduler.set_timesteps(50)  # Pre-set timesteps (note: the pipeline resets this per call via num_inference_steps)
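
# Optional memory tweak (a hedged sketch, not part of the AoT recipe itself):
# VAE slicing decodes latents one image at a time, trimming peak VRAM at
# negligible speed cost for small batches.
pipe.enable_vae_slicing()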

print("Starting AoT Compilation...")

@spaces.GPU(duration=1500)  # Reserve maximum time for startup compilation
def compile_optimized_unet():
    # 1. Apply FP8 quantization (optional; needs FP8-capable hardware, e.g. H100/H200)
    try:
        quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
        print("✅ Applied FP8 quantization to UNet.")
    except Exception as e:
        print(f"⚠️ FP8 Quantization failed (may require specific hardware/libraries): {e}")

    # 2. Define and capture example inputs for the UNet (the core engine)
    # Standard Stable Diffusion UNet inputs (batch_size=2 for classifier-free guidance)
    bsz = 2
    latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
    t = torch.randint(0, 1000, (bsz,), device="cuda")
    encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)
    
    with spaces.aoti_capture(pipe.unet) as call:
        # aoti_capture intercepts this call and records its args/kwargs for export
        pipe.unet(latent_model_input, t, encoder_hidden_states)
    
    # 3. Export the model
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )
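    # Note: torch.export specializes the graph to the captured shapes; serving
    # other resolutions or batch sizes would require passing dynamic_shapes.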
    
    # 4. Compile the exported model using AoT
    return spaces.aoti_compile(exported)

# Execute compilation during startup
compiled_unet = compile_optimized_unet()
# 5. Apply compiled model to the pipeline's UNet component
spaces.aoti_apply(compiled_unet, pipe.unet)

print("✅ AoT Compilation completed successfully.")

# --- 2. Inference Function (Running on GPU) ---

@spaces.GPU(duration=60) # Standard duration for image generation
def generate_image(
    prompt: str, 
    negative_prompt: str, 
    steps: int, 
    seed: int
):
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    
    # gr.Number returns a float, so cast before seeding the generator
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None

    steps = int(steps)
    
    # Run inference using the optimized pipeline
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        generator=generator
    ).images
    
    return result

# --- 3. Gradio Interface ---

with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></h1>
            <h2>High-Performance Creative VLM Simulator (AoT Optimized)</h2>
            <p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
        </div>
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt (Input to VLM)",
                placeholder="A futuristic city painted by Van Gogh, highly detailed.",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (What to avoid)",
                placeholder="Blurry, bad quality, low resolution",
                lines=2
            )
            
            with gr.Accordion("Generation Settings", open=True):
                steps = gr.Slider(
                    minimum=10, 
                    maximum=50, 
                    step=1, 
                    value=30, 
                    label="Inference Steps (Higher = Slower/Better)"
                )
                seed = gr.Number(
                    value=-1, 
                    label="Seed (-1 for random)"
                )
            
            generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary")
            
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Creative VLM Output",
                show_label=True,
                height=512,
                columns=2,
                object_fit="contain"
            )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, steps, seed],
        outputs=output_gallery
    )
    
    gr.Examples(
        examples=[
            ["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30, -1],
            ["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40, -1],
            ["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25, -1]
        ],
        inputs=[prompt, negative_prompt, steps, seed],
        outputs=output_gallery,
        fn=generate_image,
        cache_examples=False,
    )

demo.queue()
demo.launch()