# PyTorch 2.8 (temporary hack): upgrade torch/spaces before torch is imported below.
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image
import logging

from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Largest seed accepted by the UI (the int32 maximum, 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max


class GenerationError(Exception):
    """Raised when a generation request is rejected, e.g. by the NSFW filter."""


pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# -------------------- NSFW detector loading --------------------
# Loading failure is non-fatal: nsfw_model / nsfw_processor stay None and the
# content checks in _infer are skipped.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
nsfw_model = None
nsfw_processor = None
try:
    logger.info("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception as e:
    logger.error("Failed to load NSFW detector: %s", e, exc_info=True)
    # Reset both so a half-loaded detector is never used.
    nsfw_model = None
    nsfw_processor = None


def _nsfw_label_index() -> int:
    """Return the logit index of the "nsfw" class from the model config, defaulting to 1."""
    config = getattr(nsfw_model, "config", None)
    id2label = getattr(config, "id2label", None) or {}
    for index, label in id2label.items():
        if str(label).lower() == "nsfw":
            return int(index)
    # Historical assumption of this app: label 1 = NSFW.
    return 1


def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if the NSFW classifier's NSFW probability for *image* exceeds *threshold*.

    Requires the module-level ``nsfw_model`` and ``nsfw_processor`` to be loaded.
    """
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    nsfw_score = probs[0][_nsfw_label_index()].item()
    return nsfw_score > threshold


def _check_nsfw(image: Image.Image, message: str) -> None:
    """Raise GenerationError(message) if the NSFW detector is loaded and flags *image*."""
    if nsfw_model is not None and nsfw_processor is not None and detect_nsfw(image):
        raise GenerationError(message)


@spaces.GPU
def _infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
    """GPU worker behind :func:`infer`; see that function for parameter documentation.

    This function never raises. Failures are reported through the returned
    ``info`` dict so that :func:`infer` can convert them into a ``gr.Error``.

    Returns:
        tuple: A 4-tuple ``(image, info, seed, reuse_button)``. On success,
        ``info`` is ``{"status": "success"}``, ``seed`` is the seed actually used
        and ``reuse_button`` is a visible ``gr.Button``. On failure, ``info`` is
        ``{"error": <message>, "status": "failed"}`` and the other items are None.
    """
    progress(0, desc="Starting")

    def callback_fn(_pipeline, step, timestep, callback_kwargs):
        # Invoked by diffusers after each denoising step; must return callback_kwargs.
        logger.info("[Step %s] Timestep: %s", step, timestep)
        progress((step + 1.0) / steps, desc=f"Image generating, {step + 1}/{steps} steps")
        return callback_kwargs

    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # UI sliders / MCP clients may deliver the seed as a float; manual_seed needs an int.
        seed = int(seed)
        generator = torch.Generator().manual_seed(seed)

        if input_image is not None:
            input_image = input_image.convert("RGB")
            # NSFW check on the input image
            _check_nsfw(
                input_image,
                "The input image contains NSFW content and cannot be generated. Please modify the input image or prompt and try again.",
            )
            width, height = input_image.size
            image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                num_inference_steps=steps,
                callback_on_step_end=callback_fn,
                generator=generator,
            ).images[0]
        else:
            # No input image: generate from the prompt alone.
            image = pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                callback_on_step_end=callback_fn,
                generator=generator,
            ).images[0]

        # NSFW check on the generated image
        _check_nsfw(
            image,
            "Generated image contains NSFW content and cannot be displayed. Please modify the input image or prompt and try again.",
        )

        progress(1, desc="Complete")
        return image, {"status": "success"}, seed, gr.Button(visible=True)
    except GenerationError as e:
        # Expected rejection (e.g. NSFW filter): no traceback needed.
        logger.warning("Generation rejected: %s", e)
        return None, {"error": str(e), "status": "failed"}, None, None
    except Exception as e:
        # Unexpected failure: keep the traceback in the logs, report the message to the UI.
        logger.exception("Generation failed")
        return None, {"error": str(e), "status": "failed"}, None, None


def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.

    This function takes an input image and a text prompt to generate a modified
    version of the image based on the provided instructions. It uses the FLUX.1
    Kontext model for contextual image editing tasks. If no input image is given,
    an image is generated from the prompt alone.

    Args:
        input_image (PIL.Image.Image): The input image to be edited. Will be converted
            to RGB format if not already in that format.
        prompt (str): Text description of the desired edit to apply to the image.
            Examples: "Remove glasses", "Add a hat", "Change background to beach".
        seed (int, optional): Random seed for reproducible generation. Defaults to 42.
            Should be between 0 and MAX_SEED (2^31 - 1).
        randomize_seed (bool, optional): If True, generates a random seed instead of
            using the provided seed value. Defaults to False.
        guidance_scale (float, optional): Controls how closely the model follows the
            prompt. Higher values mean stronger adherence to the prompt but may reduce
            image quality. Range: 1.0-10.0. Defaults to 2.5.
        steps (int, optional): Number of diffusion steps to run. Range: 1-30.
            Defaults to 28.
        progress (gr.Progress, optional): Gradio progress tracker, supplied by Gradio.

    Returns:
        tuple: A 3-tuple containing:
            - PIL.Image.Image: The generated/edited image
            - int: The seed value used for generation (useful when randomize_seed=True)
            - gr.Button: Update that makes the "Reuse this image" button visible

    Raises:
        gr.Error: If generation fails or the input/output image is flagged as NSFW.

    Example:
        edited_image, used_seed, button_update = infer(
            input_image=my_image,
            prompt="Add sunglasses",
            seed=123,
            randomize_seed=False,
            guidance_scale=2.5,
        )
    """
    # Call the GPU function; it reports failures through `info` instead of raising.
    image, info, seed, reuse_button = _infer(input_image, prompt, seed, randomize_seed, guidance_scale, steps, progress)

    # On failure, surface the message to the user as a Gradio error.
    if info["status"] == "failed":
        raise gr.Error(info["error"])

    return image, seed, reuse_button


# NOTE(review): _infer is already decorated with @spaces.GPU; the outer decorator
# here may be redundant — confirm against ZeroGPU behaviour before removing.
@spaces.GPU
def infer_example(input_image, prompt):
    """Run `infer` with default settings for a gr.Examples row; returns (image, seed)."""
    image, seed, _ = infer(input_image, prompt)
    return image, seed


title = "# Image to Image AI Editor"
description = "Your Image-to-Image AI editor. Just describe changes (‘brighter, remove object, cartoon style’) and let the AI handle the rest—no Photoshop skills needed. Try the stable version at [Image to Image AI Generator](https://www.image2image.ai)."
# (image path, prompt) pairs rendered as clickable examples below the editor.
EXAMPLE_INPUTS = [
    ["flowers.png", "turn the flowers into sunflowers"],
    ["monster.png", "make this monster ride a skateboard on the beach"],
    ["cat.png", "make this cat happy"],
]

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        with gr.Row():
            # Left side: source image, prompt box and generation settings.
            with gr.Column():
                input_image = gr.Image(
                    label="Upload the image for editing",
                    type="pil",
                )
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1,
                    )

            # Right side: the result and a (initially hidden) button to edit it further.
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)

        # Example outputs are computed on first use rather than at startup.
        examples = gr.Examples(
            examples=EXAMPLE_INPUTS,
            inputs=[input_image, prompt],
            outputs=[result, seed],
            fn=infer_example,
            cache_examples="lazy",
        )

    # Clicking "Run" or pressing Enter in the prompt box both trigger a generation.
    generation_inputs = [input_image, prompt, seed, randomize_seed, guidance_scale, steps]
    generation_outputs = [result, seed, reuse_button]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    # Copy the current result back into the input slot for iterative editing.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image],
    )

demo.launch(mcp_server=True)