Spaces:
Runtime error
Runtime error
fix md and pipeline
Browse files- app/business_logic.py +12 -12
- app/ui_components.py +7 -7
- ic_custom/pipelines/ic_custom_pipeline.py +14 -1
- ic_custom/utils/model_utils.py +26 -4
app/business_logic.py
CHANGED
|
@@ -412,8 +412,8 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
| 412 |
gr.update(value="<s>Select a input mask mode</s>", visible=False),
|
| 413 |
gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
|
| 414 |
gr.update(value="<s>View or modify the target mask</s>", visible=False),
|
| 415 |
-
gr.update(value="3
|
| 416 |
-
gr.update(value="4
|
| 417 |
gr.update(visible=False),
|
| 418 |
gr.update(visible=False),
|
| 419 |
|
|
@@ -426,11 +426,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
| 426 |
gr.update(interactive=True, visible=True),
|
| 427 |
gr.update(interactive=True, visible=True),
|
| 428 |
gr.update(interactive=True, visible=True),
|
| 429 |
-
gr.update(value="3
|
| 430 |
-
gr.update(value="4
|
| 431 |
-
gr.update(value="6
|
| 432 |
-
gr.update(value="5
|
| 433 |
-
gr.update(value="7
|
| 434 |
gr.update(visible=True, value="Precise mask"),
|
| 435 |
gr.update(visible=True),
|
| 436 |
)
|
|
@@ -441,11 +441,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
| 441 |
gr.update(interactive=True, visible=True),
|
| 442 |
gr.update(interactive=True, visible=True),
|
| 443 |
gr.update(interactive=True, visible=True),
|
| 444 |
-
gr.update(value="3
|
| 445 |
-
gr.update(value="4
|
| 446 |
-
gr.update(value="6
|
| 447 |
-
gr.update(value="5
|
| 448 |
-
gr.update(value="7
|
| 449 |
gr.update(visible=True, value="User-drawn mask"),
|
| 450 |
gr.update(visible=True),
|
| 451 |
)
|
|
|
|
| 412 |
gr.update(value="<s>Select a input mask mode</s>", visible=False),
|
| 413 |
gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
|
| 414 |
gr.update(value="<s>View or modify the target mask</s>", visible=False),
|
| 415 |
+
gr.update(value="3\. Input text prompt (necessary)"),
|
| 416 |
+
gr.update(value="4\. Submit and view the output"),
|
| 417 |
gr.update(visible=False),
|
| 418 |
gr.update(visible=False),
|
| 419 |
|
|
|
|
| 426 |
gr.update(interactive=True, visible=True),
|
| 427 |
gr.update(interactive=True, visible=True),
|
| 428 |
gr.update(interactive=True, visible=True),
|
| 429 |
+
gr.update(value="3\. Select a input mask mode", visible=True),
|
| 430 |
+
gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
| 431 |
+
gr.update(value="6\. View or modify the target mask", visible=True),
|
| 432 |
+
gr.update(value="5\. Input text prompt (optional)", visible=True),
|
| 433 |
+
gr.update(value="7\. Submit and view the output", visible=True),
|
| 434 |
gr.update(visible=True, value="Precise mask"),
|
| 435 |
gr.update(visible=True),
|
| 436 |
)
|
|
|
|
| 441 |
gr.update(interactive=True, visible=True),
|
| 442 |
gr.update(interactive=True, visible=True),
|
| 443 |
gr.update(interactive=True, visible=True),
|
| 444 |
+
gr.update(value="3\. Select a input mask mode", visible=True),
|
| 445 |
+
gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
| 446 |
+
gr.update(value="6\. View or modify the target mask", visible=True),
|
| 447 |
+
gr.update(value="5\. Input text prompt (optional)", visible=True),
|
| 448 |
+
gr.update(value="7\. Submit and view the output", visible=True),
|
| 449 |
gr.update(visible=True, value="User-drawn mask"),
|
| 450 |
gr.update(visible=True),
|
| 451 |
)
|
app/ui_components.py
CHANGED
|
@@ -44,7 +44,7 @@ def create_customization_section():
|
|
| 44 |
with gr.Row():
|
| 45 |
# Add a note to remind users to click Clear before starting
|
| 46 |
md_custmization_mode = gr.Markdown(
|
| 47 |
-
"1
|
| 48 |
)
|
| 49 |
with gr.Row():
|
| 50 |
custmization_mode = gr.Radio(
|
|
@@ -61,7 +61,7 @@ def create_customization_section():
|
|
| 61 |
def create_image_input_section():
|
| 62 |
"""Create image input section optimized for left column layout."""
|
| 63 |
# Reference image section
|
| 64 |
-
md_image_reference = gr.Markdown("2
|
| 65 |
with gr.Group():
|
| 66 |
image_reference = gr.Image(
|
| 67 |
label="Reference Image",
|
|
@@ -73,7 +73,7 @@ def create_image_input_section():
|
|
| 73 |
)
|
| 74 |
|
| 75 |
# Input mask mode selection
|
| 76 |
-
md_input_mask_mode = gr.Markdown("3
|
| 77 |
with gr.Group():
|
| 78 |
input_mask_mode = gr.Radio(
|
| 79 |
["Precise mask", "User-drawn mask"],
|
|
@@ -84,7 +84,7 @@ def create_image_input_section():
|
|
| 84 |
)
|
| 85 |
|
| 86 |
# Target image section
|
| 87 |
-
md_target_image = gr.Markdown("4
|
| 88 |
|
| 89 |
# Precise mask mode
|
| 90 |
with gr.Group():
|
|
@@ -129,7 +129,7 @@ def create_image_input_section():
|
|
| 129 |
|
| 130 |
def create_prompt_section():
|
| 131 |
"""Create the text prompt input section with improved layout."""
|
| 132 |
-
md_prompt = gr.Markdown("5
|
| 133 |
with gr.Group():
|
| 134 |
prompt = gr.Textbox(
|
| 135 |
placeholder="Please input the description for the target scene.",
|
|
@@ -243,7 +243,7 @@ def create_advanced_options_section():
|
|
| 243 |
|
| 244 |
def create_mask_operation_section():
|
| 245 |
"""Create mask operation section optimized for right column (outputs)."""
|
| 246 |
-
md_mask_operation = gr.Markdown("6
|
| 247 |
|
| 248 |
with gr.Group():
|
| 249 |
# Mask gallery with responsive layout
|
|
@@ -293,7 +293,7 @@ def create_mask_operation_section():
|
|
| 293 |
|
| 294 |
def create_output_section():
|
| 295 |
"""Create the output section optimized for right column."""
|
| 296 |
-
md_submit = gr.Markdown("7
|
| 297 |
|
| 298 |
# Generation controls at top for better workflow
|
| 299 |
with gr.Group():
|
|
|
|
| 44 |
with gr.Row():
|
| 45 |
# Add a note to remind users to click Clear before starting
|
| 46 |
md_custmization_mode = gr.Markdown(
|
| 47 |
+
"1\. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
|
| 48 |
)
|
| 49 |
with gr.Row():
|
| 50 |
custmization_mode = gr.Radio(
|
|
|
|
| 61 |
def create_image_input_section():
|
| 62 |
"""Create image input section optimized for left column layout."""
|
| 63 |
# Reference image section
|
| 64 |
+
md_image_reference = gr.Markdown("2\. Input reference image")
|
| 65 |
with gr.Group():
|
| 66 |
image_reference = gr.Image(
|
| 67 |
label="Reference Image",
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
# Input mask mode selection
|
| 76 |
+
md_input_mask_mode = gr.Markdown("3\. Select input mask mode")
|
| 77 |
with gr.Group():
|
| 78 |
input_mask_mode = gr.Radio(
|
| 79 |
["Precise mask", "User-drawn mask"],
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
# Target image section
|
| 87 |
+
md_target_image = gr.Markdown("4\. Input target image & mask (Iterate clicking or brushing until the target is covered)")
|
| 88 |
|
| 89 |
# Precise mask mode
|
| 90 |
with gr.Group():
|
|
|
|
| 129 |
|
| 130 |
def create_prompt_section():
|
| 131 |
"""Create the text prompt input section with improved layout."""
|
| 132 |
+
md_prompt = gr.Markdown("5\. Input text prompt (optional)")
|
| 133 |
with gr.Group():
|
| 134 |
prompt = gr.Textbox(
|
| 135 |
placeholder="Please input the description for the target scene.",
|
|
|
|
| 243 |
|
| 244 |
def create_mask_operation_section():
|
| 245 |
"""Create mask operation section optimized for right column (outputs)."""
|
| 246 |
+
md_mask_operation = gr.Markdown("6\. View or modify the target mask")
|
| 247 |
|
| 248 |
with gr.Group():
|
| 249 |
# Mask gallery with responsive layout
|
|
|
|
| 293 |
|
| 294 |
def create_output_section():
|
| 295 |
"""Create the output section optimized for right column."""
|
| 296 |
+
md_submit = gr.Markdown("7\. Submit and view the output")
|
| 297 |
|
| 298 |
# Generation controls at top for better workflow
|
| 299 |
with gr.Group():
|
ic_custom/pipelines/ic_custom_pipeline.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
import re
|
| 3 |
from typing import List, Optional, Union
|
| 4 |
|
|
@@ -128,6 +128,10 @@ class ICCustomPipeline:
|
|
| 128 |
double_blocks_idx: str = None,
|
| 129 |
single_blocks_idx: str = None,
|
| 130 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
lora_path = resolve_model_path(
|
| 132 |
name=lora_path,
|
| 133 |
repo_id_field="repo_id",
|
|
@@ -181,6 +185,9 @@ class ICCustomPipeline:
|
|
| 181 |
self.load_model_weights(weights, strict=False)
|
| 182 |
|
| 183 |
def set_img_txt_in(self, img_txt_in_path: str):
|
|
|
|
|
|
|
|
|
|
| 184 |
img_txt_in_path = resolve_model_path(
|
| 185 |
name=img_txt_in_path,
|
| 186 |
repo_id_field="repo_id",
|
|
@@ -192,6 +199,9 @@ class ICCustomPipeline:
|
|
| 192 |
self.load_model_weights(weights, strict=False)
|
| 193 |
|
| 194 |
def set_boundary_embeddings(self, boundary_embeddings_path: str):
|
|
|
|
|
|
|
|
|
|
| 195 |
boundary_embeddings_path = resolve_model_path(
|
| 196 |
name=boundary_embeddings_path,
|
| 197 |
repo_id_field="repo_id",
|
|
@@ -203,6 +213,9 @@ class ICCustomPipeline:
|
|
| 203 |
self.load_model_weights(weights, strict=False)
|
| 204 |
|
| 205 |
def set_task_register_embeddings(self, task_register_embeddings_path: str):
|
|
|
|
|
|
|
|
|
|
| 206 |
task_register_embeddings_path = resolve_model_path(
|
| 207 |
name=task_register_embeddings_path,
|
| 208 |
repo_id_field="repo_id",
|
|
|
|
| 1 |
+
import os
|
| 2 |
import re
|
| 3 |
from typing import List, Optional, Union
|
| 4 |
|
|
|
|
| 128 |
double_blocks_idx: str = None,
|
| 129 |
single_blocks_idx: str = None,
|
| 130 |
):
|
| 131 |
+
if not os.path.exists(lora_path):
|
| 132 |
+
lora_path = "dit_lora_0x1561"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
lora_path = resolve_model_path(
|
| 136 |
name=lora_path,
|
| 137 |
repo_id_field="repo_id",
|
|
|
|
| 185 |
self.load_model_weights(weights, strict=False)
|
| 186 |
|
| 187 |
def set_img_txt_in(self, img_txt_in_path: str):
|
| 188 |
+
if not os.path.exists(img_txt_in_path):
|
| 189 |
+
img_txt_in_path = "dit_txt_img_in_0x1561"
|
| 190 |
+
|
| 191 |
img_txt_in_path = resolve_model_path(
|
| 192 |
name=img_txt_in_path,
|
| 193 |
repo_id_field="repo_id",
|
|
|
|
| 199 |
self.load_model_weights(weights, strict=False)
|
| 200 |
|
| 201 |
def set_boundary_embeddings(self, boundary_embeddings_path: str):
|
| 202 |
+
if not os.path.exists(boundary_embeddings_path):
|
| 203 |
+
boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
|
| 204 |
+
|
| 205 |
boundary_embeddings_path = resolve_model_path(
|
| 206 |
name=boundary_embeddings_path,
|
| 207 |
repo_id_field="repo_id",
|
|
|
|
| 213 |
self.load_model_weights(weights, strict=False)
|
| 214 |
|
| 215 |
def set_task_register_embeddings(self, task_register_embeddings_path: str):
|
| 216 |
+
if not os.path.exists(task_register_embeddings_path):
|
| 217 |
+
task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
|
| 218 |
+
|
| 219 |
task_register_embeddings_path = resolve_model_path(
|
| 220 |
name=task_register_embeddings_path,
|
| 221 |
repo_id_field="repo_id",
|
ic_custom/utils/model_utils.py
CHANGED
|
@@ -206,6 +206,9 @@ def load_dit(
|
|
| 206 |
model: Loaded Flux model
|
| 207 |
"""
|
| 208 |
# Loading Flux
|
|
|
|
|
|
|
|
|
|
| 209 |
logger.info("Initializing Flux model")
|
| 210 |
|
| 211 |
# Resolve checkpoint path
|
|
@@ -249,9 +252,11 @@ def load_ic_custom(
|
|
| 249 |
model: Loaded IC_Custom model
|
| 250 |
"""
|
| 251 |
logger.info("Initializing IC-Custom model")
|
| 252 |
-
|
| 253 |
# Resolve checkpoint path
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
ckpt_path = resolve_model_path(
|
| 256 |
name=name,
|
| 257 |
repo_id_field="repo_id",
|
|
@@ -312,8 +317,7 @@ def load_embedder(
|
|
| 312 |
path,
|
| 313 |
max_length=max_length,
|
| 314 |
is_clip=is_clip,
|
| 315 |
-
|
| 316 |
-
).to(device)
|
| 317 |
|
| 318 |
return model
|
| 319 |
|
|
@@ -336,7 +340,11 @@ def load_t5(
|
|
| 336 |
Returns:
|
| 337 |
model: Loaded T5 model
|
| 338 |
"""
|
|
|
|
|
|
|
|
|
|
| 339 |
logger.info(f"Loading T5 model: {name}")
|
|
|
|
| 340 |
return load_embedder(
|
| 341 |
name=name,
|
| 342 |
is_clip=False,
|
|
@@ -362,7 +370,11 @@ def load_clip(
|
|
| 362 |
Returns:
|
| 363 |
model: Loaded CLIP model
|
| 364 |
"""
|
|
|
|
|
|
|
|
|
|
| 365 |
logger.info(f"Loading CLIP model: {name}")
|
|
|
|
| 366 |
return load_embedder(
|
| 367 |
name=name,
|
| 368 |
is_clip=True,
|
|
@@ -387,6 +399,10 @@ def load_ae(
|
|
| 387 |
Returns:
|
| 388 |
model: Loaded AutoEncoder model
|
| 389 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
logger.info(f"Loading AutoEncoder model: {name}")
|
| 391 |
|
| 392 |
# Convert device string to torch.device if needed
|
|
@@ -429,6 +445,12 @@ def load_redux(
|
|
| 429 |
Returns:
|
| 430 |
model: Loaded Redux Image Encoder model
|
| 431 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
|
| 433 |
|
| 434 |
# Convert device string to torch.device if needed
|
|
|
|
| 206 |
model: Loaded Flux model
|
| 207 |
"""
|
| 208 |
# Loading Flux
|
| 209 |
+
if not os.path.exists(name):
|
| 210 |
+
name = "flux-fill-dev-dit"
|
| 211 |
+
|
| 212 |
logger.info("Initializing Flux model")
|
| 213 |
|
| 214 |
# Resolve checkpoint path
|
|
|
|
| 252 |
model: Loaded IC_Custom model
|
| 253 |
"""
|
| 254 |
logger.info("Initializing IC-Custom model")
|
| 255 |
+
|
| 256 |
# Resolve checkpoint path
|
| 257 |
+
if not os.path.exists(name):
|
| 258 |
+
name = "flux-fill-dev-dit"
|
| 259 |
+
|
| 260 |
ckpt_path = resolve_model_path(
|
| 261 |
name=name,
|
| 262 |
repo_id_field="repo_id",
|
|
|
|
| 317 |
path,
|
| 318 |
max_length=max_length,
|
| 319 |
is_clip=is_clip,
|
| 320 |
+
).to(device).to(dtype)
|
|
|
|
| 321 |
|
| 322 |
return model
|
| 323 |
|
|
|
|
| 340 |
Returns:
|
| 341 |
model: Loaded T5 model
|
| 342 |
"""
|
| 343 |
+
if not os.path.exists(name):
|
| 344 |
+
name = "t5-v1_1-xxl"
|
| 345 |
+
|
| 346 |
logger.info(f"Loading T5 model: {name}")
|
| 347 |
+
|
| 348 |
return load_embedder(
|
| 349 |
name=name,
|
| 350 |
is_clip=False,
|
|
|
|
| 370 |
Returns:
|
| 371 |
model: Loaded CLIP model
|
| 372 |
"""
|
| 373 |
+
if not os.path.exists(name):
|
| 374 |
+
name = "clip-vit-large-patch14"
|
| 375 |
+
|
| 376 |
logger.info(f"Loading CLIP model: {name}")
|
| 377 |
+
|
| 378 |
return load_embedder(
|
| 379 |
name=name,
|
| 380 |
is_clip=True,
|
|
|
|
| 399 |
Returns:
|
| 400 |
model: Loaded AutoEncoder model
|
| 401 |
"""
|
| 402 |
+
|
| 403 |
+
if not os.path.exists(name):
|
| 404 |
+
name = "flux-fill-dev-ae"
|
| 405 |
+
|
| 406 |
logger.info(f"Loading AutoEncoder model: {name}")
|
| 407 |
|
| 408 |
# Convert device string to torch.device if needed
|
|
|
|
| 445 |
Returns:
|
| 446 |
model: Loaded Redux Image Encoder model
|
| 447 |
"""
|
| 448 |
+
|
| 449 |
+
if not os.path.exists(redux_name):
|
| 450 |
+
redux_name = "flux1-redux-dev"
|
| 451 |
+
if not os.path.exists(siglip_name):
|
| 452 |
+
siglip_name = "siglip-so400m-patch14-384"
|
| 453 |
+
|
| 454 |
logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
|
| 455 |
|
| 456 |
# Convert device string to torch.device if needed
|