# frogleo's picture
# Update app.py
# 3421bda verified
# PyTorch 2.8 (temporary hack)
# Upgrade torch (from the nightly cu126 index, pinned below 2.9) and `spaces`
# before they are imported below. pip is invoked as `python -m pip` via
# sys.executable so the packages are installed into the interpreter running this
# app; a bare `pip` found on PATH may belong to a different environment.
# No shell is involved, so "torch<2.9" needs no quoting against redirection.
import os
import subprocess
import sys

_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--upgrade", "--pre",
        "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
        "torch<2.9", "spaces",
    ],
    check=False,
)
if _pip_result.returncode != 0:
    # Still best-effort (the app continues with whatever torch is installed),
    # but no longer silent about it.
    print(
        f"WARNING: pip upgrade of torch/spaces exited with code {_pip_result.returncode}",
        file=sys.stderr,
    )
# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image
import logging
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
# Logging: INFO level, timestamped records of the form
# "<time> - <logger name> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Upper bound for seeds (UI slider maximum and random draws): the int32 maximum, 2**31 - 1.
MAX_SEED = int(np.iinfo(np.int32).max)
class GenerationError(Exception):
    """Signals a failure while generating or editing an image."""
# FLUX.1 Kontext editing pipeline: loaded once at import time with bfloat16
# weights, then moved to the GPU.
_KONTEXT_MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
pipe = FluxKontextPipeline.from_pretrained(_KONTEXT_MODEL_ID, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# -------------------- NSFW detection model loading --------------------
# Best-effort: if the detector cannot be loaded, the app still starts with
# nsfw_model / nsfw_processor set to None, and callers skip the NSFW checks.
# `device` is resolved outside the try so it is always defined for detect_nsfw.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_NSFW_MODEL_ID = "Falconsai/nsfw_image_detection"
try:
    logger.info("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification
    nsfw_processor = AutoProcessor.from_pretrained(_NSFW_MODEL_ID)
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        _NSFW_MODEL_ID
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception:
    # logger.exception keeps the traceback, which the old f-string message dropped.
    logger.exception("Failed to load NSFW detector; NSFW checks are disabled")
    nsfw_model = None
    nsfw_processor = None
def _nsfw_label_index() -> int:
    """Return the output index of the "nsfw" class, read from the model config.

    Falls back to 1 (the index previously hard-coded) when the config has no
    label named "nsfw" (case-insensitive).
    """
    label2id = getattr(nsfw_model.config, "label2id", None) or {}
    for name, idx in label2id.items():
        if str(name).lower() == "nsfw":
            return int(idx)
    return 1


def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if the NSFW classifier flags *image*.

    Uses the module-level ``nsfw_processor``, ``nsfw_model`` and ``device``.
    Only call this when both the processor and model loaded successfully
    (they are None otherwise).

    Args:
        image: Image to classify.
        threshold: The image counts as NSFW when the softmax probability of
            the "nsfw" class is strictly greater than this value.

    Returns:
        True if the image is classified as NSFW, otherwise False.
    """
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Single image per call, so row 0 of the batch.
    nsfw_score = probs[0][_nsfw_label_index()].item()
    return nsfw_score > threshold
def _check_not_nsfw(image, message):
    """Raise GenerationError(message) if the NSFW detector flags *image*.

    This is a no-op when the detector failed to load (nsfw_model or
    nsfw_processor is None), so the check is best-effort.
    """
    if nsfw_model is not None and nsfw_processor is not None and detect_nsfw(image):
        raise GenerationError(message)


@spaces.GPU
def _infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.

    Edits ``input_image`` according to ``prompt``. If no input image is given,
    an image is generated from the prompt alone. When the NSFW detector is
    available, both the input image and the generated image are screened.

    This function does not raise. Any failure, including an NSFW rejection, is
    caught and reported through the returned ``info`` dict.

    Args:
        input_image (PIL.Image.Image | None): Image to edit. It is converted to
            RGB, and its size is passed to the pipeline as the requested output
            width/height. If None, the pipeline runs from the prompt only.
        prompt (str): Text description of the desired edit, e.g.
            "Remove glasses", "Add a hat", "Change background to beach".
        seed (int): Seed for reproducible generation, expected in
            [0, MAX_SEED]. Coerced to int. Ignored when ``randomize_seed`` is True.
        randomize_seed (bool): If True, draw a random seed in [0, MAX_SEED].
        guidance_scale (float): Passed to the pipeline as ``guidance_scale``
            (UI range 1.0-10.0).
        steps (int): Number of inference steps (UI range 1-30). Coerced to int.
        progress (gr.Progress): Gradio progress tracker updated after every
            diffusion step.

    Returns:
        tuple: A 4-tuple ``(image, info, seed, reuse_button)``:
            - image: the generated PIL image, or None on failure
            - info: ``{"status": "success"}`` or
              ``{"error": <message>, "status": "failed"}``
            - seed: the int seed actually used, or None on failure
            - reuse_button: ``gr.Button(visible=True)`` to reveal the reuse
              button, or None on failure
    """
    progress(0, desc="Starting")

    def callback_fn(pipe, step, timestep, callback_kwargs):
        # diffusers step-end callback: report progress and hand the kwargs
        # back to the pipeline unchanged.
        logger.info("[Step %d] Timestep: %s", step, timestep)
        progress((step + 1.0) / steps, desc=f"Image generating, {step + 1}/{steps} steps")
        return callback_kwargs

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        # Callers other than the Gradio sliders (e.g. API/MCP) may send floats.
        # torch's manual_seed and the step count need ints.
        seed = int(seed)
        steps = int(steps)
        generator = torch.Generator().manual_seed(seed)

        if input_image is not None:
            input_image = input_image.convert("RGB")
            _check_not_nsfw(
                input_image,
                "The input image contains NSFW content and cannot be generated. Please modify the input image or prompt and try again.",
            )
            width, height = input_image.size
            image = pipe(
                image=input_image,
                prompt=prompt,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                num_inference_steps=steps,
                callback_on_step_end=callback_fn,
                generator=generator,
            ).images[0]
        else:
            image = pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                callback_on_step_end=callback_fn,
                generator=generator,
            ).images[0]

        _check_not_nsfw(
            image,
            "Generated image contains NSFW content and cannot be displayed. Please modify the input image or prompt and try again.",
        )
        progress(1, desc="Complete")
        info = {
            "status": "success"
        }
        return image, info, seed, gr.Button(visible=True)
    except GenerationError as e:
        # Expected, user-facing rejection: no traceback needed.
        logger.warning("Generation rejected: %s", e)
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None, None
    except Exception as e:
        # Unexpected failure: keep the traceback in the logs, report the message to the UI.
        logger.exception("Image generation failed")
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None, None
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
    """
    Edit an image with FLUX.1 Kontext according to a text instruction.

    If no input image is given, an image is generated from the prompt alone.
    When the NSFW detector is available, NSFW inputs and outputs are rejected.

    Args:
        input_image: The image to edit (PIL image), or None to generate from the prompt only.
        prompt: Text description of the desired edit, e.g. "Remove glasses" or "Add a hat".
        seed: Random seed for reproducible results (0 to 2147483647). Ignored when randomize_seed is True.
        randomize_seed: If True, use a randomly drawn seed instead of `seed`.
        guidance_scale: How strongly to follow the prompt (1.0-10.0 in the UI).
        steps: Number of diffusion steps (1-30 in the UI).
        progress: Gradio progress tracker, forwarded to the generation step.

    Returns:
        A 3-tuple (image, seed, reuse_button): the edited PIL image, the seed
        that was actually used, and a button update that reveals the
        "Reuse this image" button.

    Raises:
        gr.Error: If generation failed or NSFW content was detected. The
            message is shown to the user.
    """
    # Run the GPU-bound work. It reports failures through `info` instead of raising.
    image, info, used_seed, reuse_button = _infer(input_image, prompt, seed, randomize_seed, guidance_scale, steps, progress)
    # Turn a reported failure into a user-visible Gradio error.
    if info.get("status") == "failed":
        raise gr.Error(info.get("error", "Image generation failed."))
    return image, used_seed, reuse_button
@spaces.GPU
def infer_example(input_image, prompt):
    """Run `infer` with its default settings for one gr.Examples row.

    Returns only the (image, seed) pair, matching the outputs wired to the
    examples component. The reuse-button update is discarded.
    """
    # NOTE(review): `infer` calls `_infer`, which is itself decorated with
    # @spaces.GPU, so this nests two GPU-decorated calls. Confirm the outer
    # decorator is actually needed.
    edited, used_seed, _unused_button = infer(input_image, prompt)
    return edited, used_seed
# Header text rendered as Markdown at the top of the page.
title = "# Image to Image AI Editor"
description = (
    "Your Image-to-Image AI editor. "
    "Just describe changes (‘brighter, remove object, cartoon style’) "
    "and let the AI handle the rest—no Photoshop skills needed. "
    "Try the stable version at "
    "[Image to Image AI Generator](https://www.image2image.ai)."
)
# UI layout and event wiring. Statement order matters here: Gradio lays out
# components in the order they are created inside each layout context.
# NOTE(review): the nesting below was reconstructed from a copy with stripped
# indentation (Examples inside the outer Column, event listeners at Blocks
# level). Confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Column():
        with gr.Row():
            # Left column: input image, prompt + Run button, advanced settings.
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Accordion("Advanced Settings", open=False):
                    # Also an output of the Run event: the seed actually used is written back here.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )
            # Right column: result image and the reuse button, hidden until a
            # successful run returns gr.Button(visible=True).
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
        # Example rows run infer_example (default settings) and fill result + seed.
        # cache_examples="lazy" presumably caches each example on first use. See the Gradio docs.
        examples = gr.Examples(
            examples=[
                ["flowers.png", "turn the flowers into sunflowers"],
                ["monster.png", "make this monster ride a skateboard on the beach"],
                ["cat.png", "make this cat happy"]
            ],
            inputs=[input_image, prompt],
            outputs=[result, seed],
            fn=infer_example,
            cache_examples="lazy"
        )
    # Run on button click or Enter in the prompt box. infer returns
    # (image, seed, reuse_button), matching these three outputs in order.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [input_image, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs = [result, seed, reuse_button]
    )
    # Copy the current result back into the input slot for iterative editing.
    reuse_button.click(
        fn = lambda image: image,
        inputs = [result],
        outputs = [input_image]
    )
# Start the server only when this file is executed as a script, so importing the
# module (e.g. for tests or tooling) does not block inside launch().
# NOTE(review): Spaces is presumed to run `python app.py` (so __name__ == "__main__"). Confirm.
if __name__ == "__main__":
    demo.launch(mcp_server=True)