Spaces:
Runtime error
Runtime error
import base64
import gc
import io
import os
import threading
from pathlib import Path

import torch
from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from PIL import Image
# Flask application; CORS(app) with no options applies flask_cors defaults
# to every route.
app = Flask(__name__)
CORS(app)

# Inference pipeline, created on first use by load_models().
pipe = None
# Device the pipeline runs on: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Serialises the one-time pipeline construction in load_models(), so that
# requests arriving before the model is ready do not each load the weights.
_pipe_lock = threading.Lock()


def load_models():
    """Lazily create the Qwen-Image ControlNet inpainting pipeline.

    The first call does the following:
    - loads the ``InstantX/Qwen-Image-ControlNet-Inpainting`` ControlNet and
      the ``Qwen/Qwen-Image`` base pipeline via ``from_pretrained``;
    - moves the pipeline to ``device``;
    - enables attention slicing when the pipeline offers it;
    - caches the pipeline in the module-level ``pipe``.

    Later calls return the cached object.

    Returns:
        The loaded pipeline.

    Raises:
        Exception: any error from loading or moving the models is printed and
            re-raised. ``pipe`` is left unset in that case, so the next call
            retries from scratch.
    """
    global pipe
    # Fast path once loaded.
    if pipe is not None:
        return pipe
    with _pipe_lock:
        # Another thread may have finished loading while we waited for the lock.
        if pipe is not None:
            return pipe
        print("Loading models...")
        print(f"Using device: {device}")
        base_model = "Qwen/Qwen-Image"
        controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
        try:
            # Use float16 for better compatibility.
            # NOTE(review): confirm float16 is numerically stable for Qwen-Image;
            # bfloat16 may be the safer choice on GPUs that support it.
            dtype = torch.float16 if device == "cuda" else torch.float32
            controlnet = QwenImageControlNetModel.from_pretrained(
                controlnet_model,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            # Build into a local and publish to the global `pipe` only after
            # every setup step succeeded. A failure in .to() (e.g. out of
            # memory) then cannot leave a half-initialised pipeline cached.
            new_pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
                base_model,
                controlnet=controlnet,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            new_pipe.to(device)
            # Enable memory efficient attention if available.
            if hasattr(new_pipe, 'enable_attention_slicing'):
                new_pipe.enable_attention_slicing()
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise
        pipe = new_pipe
        print("Models loaded successfully!")
        return pipe
def image_to_base64(image):
    """Encode an image as PNG and return the bytes as base64 text.

    The image's ``save(fp, format="PNG")`` method writes into an in-memory
    buffer. The result carries no ``data:`` URI prefix.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
def base64_to_image(base64_string):
    """Decode base64 text into a PIL Image.

    Data-URI input such as ``data:image/png;base64,<payload>`` is accepted.
    When the text contains a comma, only the segment after the first comma
    (up to any second comma) is decoded.

    Returns:
        The image produced by ``Image.open`` over an in-memory buffer.
    """
    has_header = ',' in base64_string
    payload = base64_string.split(',')[1] if has_header else base64_string
    return Image.open(io.BytesIO(base64.b64decode(payload)))
# NOTE(review): route path inferred from the view name; confirm against callers.
@app.route('/')
def home():
    """Report that the service is running, which model it serves, and its device."""
    return jsonify({
        "status": "running",
        "model": "Qwen-Image-ControlNet-Inpainting",
        # Reuse the module-level `device` so this agrees with health() and with
        # where load_models() places the pipeline.
        "device": device
    })
class _InvalidRequest(ValueError):
    """Client input error; its message is returned to the caller with HTTP 400."""


def _parse_inpaint_request(data):
    """Validate the decoded JSON body of an inpaint request.

    Args:
        data: result of ``request.get_json(silent=True)``. It is ``None`` when
            the body is missing or is not valid JSON.

    Returns:
        dict with ``prompt``, ``negative_prompt``, ``image`` and ``mask`` (as
        sent), plus ``num_steps``, ``cfg_scale``, ``controlnet_scale`` and
        ``seed`` converted to numbers (defaults 30, 4.0, 1.0 and 42).

    Raises:
        _InvalidRequest: if any of the following holds:
            - the body is not a JSON object;
            - ``prompt``, ``image`` or ``mask`` is missing or empty;
            - a numeric field cannot be converted;
            - ``num_steps`` is below 1.
    """
    if not isinstance(data, dict):
        raise _InvalidRequest("Request body must be a JSON object")
    prompt = data.get('prompt', '')
    image_base64 = data.get('image')
    mask_base64 = data.get('mask')
    if not prompt:
        raise _InvalidRequest("Prompt is required")
    if not image_base64 or not mask_base64:
        raise _InvalidRequest("Image and mask are required")
    try:
        num_steps = int(data.get('num_steps', 30))
        cfg_scale = float(data.get('cfg_scale', 4.0))
        controlnet_scale = float(data.get('controlnet_scale', 1.0))
        seed = int(data.get('seed', 42))
    except (TypeError, ValueError, OverflowError) as err:
        # e.g. "num_steps": "abc" or null is a client mistake, not a server error.
        raise _InvalidRequest(f"Invalid numeric parameter: {err}") from err
    if num_steps < 1:
        raise _InvalidRequest("num_steps must be at least 1")
    return {
        'prompt': prompt,
        'negative_prompt': data.get('negative_prompt', ''),
        'image': image_base64,
        'mask': mask_base64,
        'num_steps': num_steps,
        'cfg_scale': cfg_scale,
        'controlnet_scale': controlnet_scale,
        'seed': seed,
    }


def _prepare_inputs(image_base64, mask_base64, max_size=1024):
    """Decode the base64 image and mask and ready them for the pipeline.

    The image is converted to RGB and the mask to single-channel ``'L'``. If
    either edge of the image exceeds ``max_size``, the image and mask are
    both resized with LANCZOS. The image's longer edge becomes ``max_size``
    with aspect ratio kept, which limits memory use.

    Returns:
        ``(control_image, control_mask)``.

    Raises:
        _InvalidRequest: the data is not valid base64 or not a readable image.
    """
    try:
        # convert() forces PIL to decode now, so corrupt uploads are reported
        # as client errors. Converting to RGB also normalises RGBA, palette
        # and grayscale uploads to 3 channels for the pipeline.
        # NOTE(review): confirm against the diffusers pipeline's preprocessing.
        control_image = base64_to_image(image_base64).convert('RGB')
        control_mask = base64_to_image(mask_base64).convert('L')
    except (TypeError, ValueError, OSError) as err:
        raise _InvalidRequest(f"Could not decode image or mask: {err}") from err
    # Resize if too large to prevent OOM.
    if control_image.width > max_size or control_image.height > max_size:
        ratio = max_size / max(control_image.width, control_image.height)
        new_size = (int(control_image.width * ratio), int(control_image.height * ratio))
        control_image = control_image.resize(new_size, Image.LANCZOS)
        control_mask = control_mask.resize(new_size, Image.LANCZOS)
    return control_image, control_mask


# NOTE(review): path and method inferred from the view name and its JSON body;
# confirm against the client.
@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint the masked region of an image according to a text prompt.

    Expects a JSON object with the following fields:
    - required: ``prompt``, ``image`` and ``mask`` (base64 image data, optionally
      as data URIs);
    - optional: ``negative_prompt``, ``num_steps``, ``cfg_scale``,
      ``controlnet_scale`` and ``seed``.

    Returns:
        - 200 with ``{"success": True, "image": <PNG data URI>, "message": ...}``;
        - 400 with ``{"success": False, "error": ...}`` for invalid input;
        - 500 with the same error shape when model loading or generation fails.
    """
    try:
        # Validate before touching the model, so malformed requests never
        # trigger the expensive first-time load.
        params = _parse_inpaint_request(request.get_json(silent=True))
        control_image, control_mask = _prepare_inputs(params['image'], params['mask'])

        # Load models on first request.
        pipeline = load_models()

        print(f"Generating image with prompt: {params['prompt']}")
        generator = torch.Generator(device=pipeline.device).manual_seed(params['seed'])
        result = pipeline(
            prompt=params['prompt'],
            negative_prompt=params['negative_prompt'],
            control_image=control_image,
            control_mask=control_mask,
            controlnet_conditioning_scale=params['controlnet_scale'],
            width=control_image.size[0],
            height=control_image.size[1],
            num_inference_steps=params['num_steps'],
            true_cfg_scale=params['cfg_scale'],
            generator=generator,
        ).images[0]

        return jsonify({
            "success": True,
            "image": f"data:image/png;base64,{image_to_base64(result)}",
            "message": "Image generated successfully"
        })
    except _InvalidRequest as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400
    except Exception as e:
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
    finally:
        # Release memory after every attempt, successful or not.
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
# NOTE(review): route path inferred from the view name; confirm against callers.
@app.route('/health')
def health():
    """Report liveness, CUDA availability, the chosen device, and whether the
    module-level pipeline has been loaded yet."""
    return jsonify({
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "device": device,
        "models_loaded": pipe is not None
    })
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    server_port = int(os.getenv('PORT', '7860'))
    app.run(debug=False, host='0.0.0.0', port=server_port)