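"""Flask service exposing Qwen-Image ControlNet inpainting over HTTP.

Endpoints:
    GET  /         -- basic status info
    GET  /health   -- health check (CUDA availability, model load state)
    POST /inpaint  -- inpaint a base64-encoded image using a base64-encoded mask
"""
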
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
from PIL import Image
import io
import base64
import os
from pathlib import Path
import gc

app = Flask(__name__)
CORS(app)

# Global variables for lazy loading
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_models():
    """Lazy load models when first needed"""
    global pipe
    if pipe is not None:
        return pipe
    
    print("Loading models...")
    print(f"Using device: {device}")
    
    base_model = "Qwen/Qwen-Image"
    controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
    
    try:
        # Half precision keeps GPU memory usage down; CPU inference needs float32
        dtype = torch.float16 if device == "cuda" else torch.float32
        
        controlnet = QwenImageControlNetModel.from_pretrained(
            controlnet_model, 
            torch_dtype=dtype,
            use_safetensors=True
        )
        
        pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
            base_model, 
            controlnet=controlnet, 
            torch_dtype=dtype,
            use_safetensors=True
        )
        pipe.to(device)
        
        # Enable memory efficient attention if available
        if hasattr(pipe, 'enable_attention_slicing'):
            pipe.enable_attention_slicing()
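        # On low-VRAM GPUs, pipe.enable_model_cpu_offload() could be swapped in
        # for pipe.to(device) above, trading inference speed for a smaller footprint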
        
        print("Models loaded successfully!")
        return pipe
        
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise

def image_to_base64(image):
    """Convert PIL Image to base64 string"""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def base64_to_image(base64_string):
    """Convert base64 string to PIL Image"""
    if ',' in base64_string:
        base64_string = base64_string.split(',')[1]
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image

@app.route('/')
def home():
    return jsonify({
        "status": "running",
        "model": "Qwen-Image-ControlNet-Inpainting",
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    })

@app.route('/inpaint', methods=['POST'])
def inpaint():
    try:
        # Load models on first request
        pipeline = load_models()
        
        # get_json(silent=True) returns None on a bad payload instead of raising
        data = request.get_json(silent=True) or {}
        
        # Get parameters
        prompt = data.get('prompt', '')
        negative_prompt = data.get('negative_prompt', '')
        image_base64 = data.get('image')
        mask_base64 = data.get('mask')
        num_steps = int(data.get('num_steps', 30))
        cfg_scale = float(data.get('cfg_scale', 4.0))
        controlnet_scale = float(data.get('controlnet_scale', 1.0))
        seed = int(data.get('seed', 42))
        
        # Validate inputs
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400
        if not image_base64 or not mask_base64:
            return jsonify({"error": "Image and mask are required"}), 400
        
        # Decode base64 payloads; force RGB so alpha/palette PNGs don't trip the pipeline
        control_image = base64_to_image(image_base64).convert('RGB')
        control_mask = base64_to_image(mask_base64)
        
        # Resize if too large to prevent OOM
        max_size = 1024
        if control_image.width > max_size or control_image.height > max_size:
            ratio = max_size / max(control_image.width, control_image.height)
            new_size = (int(control_image.width * ratio), int(control_image.height * ratio))
            control_image = control_image.resize(new_size, Image.Resampling.LANCZOS)
            control_mask = control_mask.resize(new_size, Image.Resampling.LANCZOS)
        
        # Ensure the mask is single-channel grayscale; by the usual inpainting
        # convention, white pixels mark the region to repaint
        if control_mask.mode != 'L':
            control_mask = control_mask.convert('L')
        
        # Generate image
        print(f"Generating image with prompt: {prompt}")
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        
        result = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            control_mask=control_mask,
            controlnet_conditioning_scale=controlnet_scale,
            width=control_image.size[0],
            height=control_image.size[1],
            num_inference_steps=num_steps,
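            # Qwen-Image uses true classifier-free guidance (true_cfg_scale) rather
            # than the guidance_scale argument found on most diffusers pipelines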
            true_cfg_scale=cfg_scale,
            generator=generator,
        ).images[0]
        
        # Convert result to base64
        result_base64 = image_to_base64(result)
        
        # Collect Python garbage first so freed tensors can leave the CUDA cache
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        
        return jsonify({
            "success": True,
            "image": f"data:image/png;base64,{result_base64}",
            "message": "Image generated successfully"
        })
        
    except Exception as e:
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        
        # Clear memory on error as well
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
            
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500

@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "device": device,
        "models_loaded": pipe is not None
    })

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)
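
# Example request once the server is running (payload is illustrative; "image"
# and "mask" must be base64-encoded image data):
#
#   curl -X POST http://localhost:7860/inpaint \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "a red sports car", "image": "<base64>", "mask": "<base64>"}'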