frogleo committed
Commit 4accd7e · verified · 1 Parent(s): b42cbda

Update app.py

Files changed (1)
  1. app.py +93 -24
app.py CHANGED
@@ -9,16 +9,54 @@ import spaces
 import torch
 import random
 from PIL import Image
+import logging
 
 from diffusers import FluxKontextPipeline
 from diffusers.utils import load_image
 
+# Enhanced logging configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)
+
 MAX_SEED = np.iinfo(np.int32).max
 
+class GenerationError(Exception):
+    """Custom exception for generation errors"""
+    pass
+
 pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
 
+# -------------------- NSFW detection model loading --------------------
+try:
+    logger.info("Loading NSFW detector...")
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    from transformers import AutoProcessor, AutoModelForImageClassification
+    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
+    nsfw_model = AutoModelForImageClassification.from_pretrained(
+        "Falconsai/nsfw_image_detection"
+    ).to(device)
+    logger.info("NSFW detector loaded successfully.")
+except Exception as e:
+    logger.error(f"Failed to load NSFW detector: {e}")
+    nsfw_model = None
+    nsfw_processor = None
+
+def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
+    """Return True if the image is classified as NSFW."""
+    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = nsfw_model(**inputs)
+    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    nsfw_score = probs[0][1].item()  # label 1 = NSFW
+    return nsfw_score > threshold
+
+
 @spaces.GPU
-def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
+def _infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
     """
     Perform image editing using the FLUX.1 Kontext pipeline.
 
@@ -69,30 +107,61 @@ def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    if input_image:
-        input_image = input_image.convert("RGB")
-        image = pipe(
-            image=input_image,
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            width = input_image.size[0],
-            height = input_image.size[1],
-            num_inference_steps=steps,
-            callback_on_step_end=callback_fn,
-            generator=torch.Generator().manual_seed(seed),
-        ).images[0]
-    else:
-        image = pipe(
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=steps,
-            callback_on_step_end=callback_fn,
-            generator=torch.Generator().manual_seed(seed),
-        ).images[0]
-
-    progress(1, desc="Complete")
-    return image, seed, gr.Button(visible=True)
+
+    try:
+        if input_image:
+            input_image = input_image.convert("RGB")
+            image = pipe(
+                image=input_image,
+                prompt=prompt,
+                guidance_scale=guidance_scale,
+                width = input_image.size[0],
+                height = input_image.size[1],
+                num_inference_steps=steps,
+                callback_on_step_end=callback_fn,
+                generator=torch.Generator().manual_seed(seed),
+            ).images[0]
+        else:
+            image = pipe(
+                prompt=prompt,
+                guidance_scale=guidance_scale,
+                num_inference_steps=steps,
+                callback_on_step_end=callback_fn,
+                generator=torch.Generator().manual_seed(seed),
+            ).images[0]
+        # NSFW check on the generated image
+        if nsfw_model and nsfw_processor:
+            if detect_nsfw(image):
+                msg = "Generated image contains NSFW content and cannot be displayed. Please modify your prompt and try again."
+                raise GenerationError(msg)
+
+        progress(1, desc="Complete")
+        info = {
+            "status": "success"
+        }
+        return image, info, seed, gr.Button(visible=True)
+    except GenerationError as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info, None, None
+    except Exception as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info, None, None
+
+
+def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress()):
+    # Call the GPU-decorated worker
+    image, info, seed, reuse_button = _infer(input_image, prompt, seed, randomize_seed, guidance_scale, steps, progress)
+    # If generation failed, surface the error in the Gradio UI
+    if info["status"] == "failed":
+        raise gr.Error(info["error"])
+    # Return the generated image
+    return image, seed, reuse_button
 
 @spaces.GPU
 def infer_example(input_image, prompt):
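
For readers skimming the diff: the change splits generation into a GPU-decorated worker, _infer, which catches failures and returns an info dict instead of raising inside the @spaces.GPU context, and a thin infer wrapper that re-raises a failed status as gr.Error so the message is shown in the Gradio UI. Below is a minimal sketch of how the new NSFW gate could be exercised in isolation; the module name app and the file test.png are illustrative assumptions, not part of this commit.

# Hypothetical smoke test for the NSFW gate introduced above.
# Assumes this Space's app.py is importable as `app` and that a local
# image named test.png exists; both names are illustrative only.
from PIL import Image

import app  # importing loads the FLUX pipeline and the Falconsai/nsfw_image_detection classifier

img = Image.open("test.png").convert("RGB")

if app.nsfw_model is not None and app.nsfw_processor is not None:
    # detect_nsfw() returns True when the NSFW class probability exceeds the 0.5 default threshold
    print("NSFW:", app.detect_nsfw(img))
else:
    # Mirrors the app's behaviour: if the detector failed to load, the check is skipped
    print("NSFW detector not available; _infer would skip the check.")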