from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
from PIL import Image
import io
import base64
import os
import gc

app = Flask(__name__)
CORS(app)

# Global variables for lazy loading
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_models():
    """Lazy load models when first needed"""
    global pipe
    if pipe is not None:
        return pipe

    print("Loading models...")
    print(f"Using device: {device}")

    base_model = "Qwen/Qwen-Image"
    controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"

    try:
        # float16 on GPU keeps memory use down; fall back to float32 on CPU
        dtype = torch.float16 if device == "cuda" else torch.float32
        controlnet = QwenImageControlNetModel.from_pretrained(
            controlnet_model,
            torch_dtype=dtype,
            use_safetensors=True
        )
        pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=dtype,
            use_safetensors=True
        )
        pipe.to(device)

        # Enable memory efficient attention if available
        if hasattr(pipe, 'enable_attention_slicing'):
            pipe.enable_attention_slicing()

        print("Models loaded successfully!")
        return pipe
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
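
# Optional alternative (an assumption, not part of the original setup): on low-VRAM
# GPUs, replacing pipe.to(device) above with pipe.enable_model_cpu_offload()
# (requires the accelerate package) trades speed for memory by keeping each
# submodule on the GPU only while it runs.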

def image_to_base64(image):
    """Convert PIL Image to base64 string"""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def base64_to_image(base64_string):
    """Convert base64 string to PIL Image"""
    # Strip a data-URL prefix such as "data:image/png;base64," if present
    if ',' in base64_string:
        base64_string = base64_string.split(',')[1]
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image
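
# Illustrative round-trip of the two helpers above (a hypothetical snippet, not
# part of the app's control flow):
#   img = Image.new("RGB", (64, 64), "white")
#   restored = base64_to_image(image_to_base64(img))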

@app.route('/')
def home():
    return jsonify({
        "status": "running",
        "model": "Qwen-Image-ControlNet-Inpainting",
        "device": device
    })

@app.route('/inpaint', methods=['POST'])
def inpaint():
    try:
        # Load models on first request
        pipeline = load_models()
        data = request.json

        # Get parameters
        prompt = data.get('prompt', '')
        negative_prompt = data.get('negative_prompt', '')
        image_base64 = data.get('image')
        mask_base64 = data.get('mask')
        num_steps = int(data.get('num_steps', 30))
        cfg_scale = float(data.get('cfg_scale', 4.0))
        controlnet_scale = float(data.get('controlnet_scale', 1.0))
        seed = int(data.get('seed', 42))

        # Validate inputs
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400
        if not image_base64 or not mask_base64:
            return jsonify({"error": "Image and mask are required"}), 400

        # Convert base64 to images
        control_image = base64_to_image(image_base64)
        control_mask = base64_to_image(mask_base64)

        # Ensure the control image is RGB (uploaded PNGs may carry an alpha channel)
        if control_image.mode != 'RGB':
            control_image = control_image.convert('RGB')

        # Resize if too large to prevent OOM
        max_size = 1024
        if control_image.width > max_size or control_image.height > max_size:
            ratio = max_size / max(control_image.width, control_image.height)
            new_size = (int(control_image.width * ratio), int(control_image.height * ratio))
            control_image = control_image.resize(new_size, Image.LANCZOS)
            # NEAREST keeps the mask binary; LANCZOS would blur its edges into grays
            control_mask = control_mask.resize(new_size, Image.NEAREST)

        # Ensure mask is in L mode (grayscale)
        if control_mask.mode != 'L':
            control_mask = control_mask.convert('L')

        # Generate image
        print(f"Generating image with prompt: {prompt}")
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        result = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            control_mask=control_mask,
            controlnet_conditioning_scale=controlnet_scale,
            width=control_image.size[0],
            height=control_image.size[1],
            num_inference_steps=num_steps,
            true_cfg_scale=cfg_scale,
            generator=generator,
        ).images[0]

        # Convert result to base64
        result_base64 = image_to_base64(result)

        # Clear memory
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return jsonify({
            "success": True,
            "image": f"data:image/png;base64,{result_base64}",
            "message": "Image generated successfully"
        })
    except Exception as e:
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()

        # Clear memory on error
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
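
# Example client call (illustrative sketch; assumes the requests package and a
# server on localhost:7860 -- the field names match the /inpaint handler above):
#   import requests
#   payload = {"prompt": "a red vintage car", "image": img_b64, "mask": mask_b64, "seed": 42}
#   resp = requests.post("http://localhost:7860/inpaint", json=payload)
#   result_b64 = resp.json()["image"]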

@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "device": device,
        "models_loaded": pipe is not None
    })

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)