Spaces:
Runtime error
Runtime error
| import os | |
| import copy | |
| #import spaces | |
| from main import run_main | |
| from PIL import Image | |
| import matplotlib | |
| import numpy as np | |
| import gradio as gr | |
| from utils import load_mask, load_mask_edit | |
| from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean | |
| from pathlib import Path | |
| from PIL import Image | |
| from functools import partial | |
| import time | |
# Side length of the square area used to display/edit images.
LENGTH = 512
# Transparency of the mask overlay in the display.
TRANSPARENCY = 150
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an empty mask labelled ``"new"`` to both lists.

    The new mask is all zeros with the shape and dtype of the first mask in
    ``mask_np_list_updated``. Both lists are modified in place and the same
    list objects are returned.
    """
    template = mask_np_list_updated[0]
    mask_np_list_updated.append(np.zeros_like(template))
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
def create_segmentation(mask_np_list):
    """Render a list of masks as a single colour-coded PIL image.

    Each mask ``i`` is painted with the ``i``-th colour of a viridis colormap
    resampled to ``len(mask_np_list)`` entries, and the painted layers are
    summed.

    Args:
        mask_np_list: list of 2-D arrays (indexed as ``m[:, :, np.newaxis]``).
            Presumably 0/1 masks that do not overlap -- overlapping masks
            would have their colours added together; verify against caller.

    Returns:
        A ``PIL.Image`` built from the summed colour layers scaled to 0-255.
    """
    # ``import matplotlib`` alone does not load the ``pyplot`` submodule, so
    # attribute access on the bare package can fail. Import what is used here.
    import matplotlib.colors
    import matplotlib.pyplot

    viridis = matplotlib.pyplot.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        rgb = np.asarray(matplotlib.colors.to_rgb(viridis(i)))
        # Broadcast (H, W, 1) * (3,) -> (H, W, 3): the mask painted in its colour.
        segmentation = segmentation + m[:, :, np.newaxis] * rgb
    return Image.fromarray(np.uint8(segmentation * 255))
#@spaces.GPU
def run_segmentation_wrapper(image):
    """Run segmentation on the canvas image and build the UI outputs.

    Args:
        image: the canvas content; ``None`` when nothing has been uploaded.
            Otherwise an object with a ``.shape`` attribute that is handed to
            ``run_segmentation``.

    Returns:
        An 8-tuple: ``(image, segmentation, mask_np_list, mask_label_list,
        image, slider, slider, message)``. On success the slider is visible
        and ranges over the mask indices. On failure the first five entries
        are ``None`` and the slider is hidden.
    """
    def _failure(message):
        # Clear every image/mask output and hide the mask-id sliders.
        gr.Warning(message)
        hidden = gr.Slider(value=0, minimum=0, maximum=1, step=1, visible=False)
        return None, None, None, None, None, hidden, hidden, message

    # Only a missing image should be reported as "please upload".
    if image is None:
        return _failure('Please upload an image before proceeding.')

    try:
        print(image.shape)
        image, mask_np_list, mask_label_list = run_segmentation(image)
        segmentation = create_segmentation(mask_np_list)
        print("!!", len(mask_np_list))
    except Exception as e:
        # UI callback boundary: report the real cause instead of crashing
        # or blaming a missing upload.
        print(e)
        return _failure('Segmentation failed: {}'.format(e))

    max_val = len(mask_np_list) - 1
    sliderup = gr.Slider(value=0, minimum=0, maximum=max_val, step=1, visible=True)
    message = 'Segmentation finish. Select mask id and move to the next step.'
    gr.Info(message)
    return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup, message
def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Overlay ``foreimg`` on ``backimg`` inside the area selected by ``mask_np``.

    A copy of ``foreimg`` gets ``transparency`` as its alpha channel and is
    pasted over a copy of ``backimg``. The result keeps the blended pixels
    where ``mask_np`` is 1 and the untouched background where it is 0.
    Neither input image is modified.

    Args:
        backimg: background PIL image.
        foreimg: foreground PIL image to overlay.
        mask_np: 2-D array used as a per-pixel blend weight.
        transparency: alpha value given to the foreground copy.

    Returns:
        A new PIL image with the composite.
    """
    untouched_np = np.array(backimg)

    # Alpha-blend the whole foreground onto a copy of the background.
    overlay = foreimg.copy()
    overlay.putalpha(transparency)
    blended = backimg.copy()
    blended.paste(overlay, (0, 0), overlay)
    blended_np = np.array(blended)

    # Add a channel axis so the mask broadcasts over the colour channels.
    weight = mask_np[:, :, np.newaxis]
    composite = blended_np * weight + (1 - weight) * untouched_np
    return Image.fromarray(np.uint8(composite))
def show_segmentation(image, segmentation, flag):
    """Toggle the segmentation overlay on ``image``.

    Args:
        image: PIL image (its ``.size`` is read as a ``(width, height)`` pair).
        segmentation: PIL image with the colour-coded segmentation.
        flag: current toggle state. Only the exact value ``False`` turns the
            overlay on; any other value turns it off.

    Returns:
        ``(image_with_overlay, True)`` when the overlay was switched on,
        otherwise ``(image, False)``.
    """
    if flag is False:
        # PIL reports (width, height) but NumPy arrays are (rows, cols), so
        # the mask must be built as (height, width) to match non-square images.
        width, height = image.size[0], image.size[1]
        mask_np = np.ones([height, width]).astype(np.uint8)
        image_edit = transparent_paste_with_mask(
            image, segmentation, mask_np, transparency=TRANSPARENCY
        )
        return image_edit, True
    return image, False
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Add the region drawn on the canvas to mask ``idx``.

    The first channel of ``canvas["mask"]`` is scaled by 1/255 and cast to
    ``uint8``, then merged into the selected mask with ``mask_union``. The
    list is then passed through ``process_mask_to_follow_priority`` with
    priority 1 for the edited mask and 0 for all others.

    Args:
        canvas: mapping whose ``"mask"`` entry is an (H, W, C) array.
        image: PIL image the refreshed segmentation is overlaid on.
        idx: index of the mask being edited.
        mask_np_list: current list of masks; not modified.

    Returns:
        ``(updated_mask_list, image_with_overlay)``.
    """
    selected = mask_np_list[idx]
    drawn = np.uint8(canvas["mask"][:, :, 0] / 255.)

    # Replace only the selected entry with its union with the drawn region.
    merged = [
        mask_union(selected, drawn) if position == idx else existing
        for position, existing in enumerate(mask_np_list)
    ]

    priorities = [0] * len(merged)
    priorities[idx] = 1
    merged = process_mask_to_follow_priority(merged, priorities)

    # Full-frame mask: show the overlay everywhere.
    everywhere = np.ones(selected.shape[:2]).astype(np.uint8)
    overlay = create_segmentation(merged)
    image_edit = transparent_paste_with_mask(image, overlay, everywhere, transparency=TRANSPARENCY)
    return merged, image_edit
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask chosen on the slider and suggest a prompt for it.

    Args:
        index: mask id selected on the slider.
        image: PIL image the highlight is drawn on.
        mask_np_list_updated: list of masks.
        mask_label_list: labels parallel to the masks. The text after the
            last ``'-'`` is dropped (presumably an instance id -- verify
            against the segmentation code).

    Returns:
        ``(image, label, prompt)``. When ``index`` is past the end of the
        mask list the input image is returned with the label
        ``"out of range"`` and an empty prompt.
    """
    if index > len(mask_np_list_updated) - 1:
        return image, "out of range", ""

    mask_np = mask_np_list_updated[index]
    full_label = mask_label_list[index]

    # Strip the trailing "-<suffix>". A label without any '-' is kept whole:
    # the previous ``label[:label.rfind('-')]`` silently dropped its last
    # character because ``rfind`` returns -1 when nothing is found.
    base, dash, _suffix = full_label.rpartition('-')
    mask_label = base if dash else full_label

    # Hand-picked default prompts for a few labels; others reuse the label.
    prompt_overrides = {
        'handbag': "white handbag",
        'person': "little boy",
        'wall-other-merged': "white wall",
        'table-merged': "table",
    }
    mask_prompt = prompt_overrides.get(mask_label, mask_label)

    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    gr.Info('Edit ' + mask_label)
    return new_image, mask_label, mask_prompt
def image_change():
    """Return a hidden mask-id slider reset to value 0 on a 0-1 range."""
    hidden_slider = gr.Slider(value=0, minimum=0, maximum=1, step=1, visible=False)
    return hidden_slider
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the original (pre-edit) masks of ``input_folder``.

    Each mask is written to ``mask<idx>_<label>.npy`` and a visualisation of
    the whole list is written to ``seg_current.png`` through
    ``visualize_mask_list_clean``.

    Args:
        mask_np_list_updated: list of mask arrays.
        mask_label_list: labels parallel to the masks.
        input_folder: destination directory; created if it does not exist.
    """
    # The masks are expected to sum to exactly 1 at every position. A
    # violation is reported but saving still proceeds; the earlier
    # ``pdb.set_trace()`` here would stall a non-interactive server.
    try:
        masks_ok = bool(np.all(sum(mask_np_list_updated) == 1))
    except (TypeError, ValueError):
        # e.g. masks with incompatible shapes
        masks_ok = False
    if not masks_ok:
        print("please check mask")

    os.makedirs(input_folder, exist_ok=True)
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the edited masks of ``input_folder``.

    Each mask is written to ``maskEdited<idx>_<label>.npy`` and a
    visualisation of the whole list is written to ``seg_edited.png`` through
    ``visualize_mask_list_clean``.

    Args:
        mask_np_list_updated: list of mask arrays.
        mask_label_list: labels parallel to the masks.
        input_folder: destination directory; created if it does not exist.
    """
    # The masks are expected to sum to exactly 1 at every position. A
    # violation is reported but saving still proceeds; the earlier
    # ``pdb.set_trace()`` here would stall a non-interactive server.
    try:
        masks_ok = bool(np.all(sum(mask_np_list_updated) == 1))
    except (TypeError, ValueError):
        # e.g. masks with incompatible shapes
        masks_ok = False
    if not masks_ok:
        print("please check mask")

    os.makedirs(input_folder, exist_ok=True)
    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def button_clickable(is_clickable):
    """Return a button update whose ``interactive`` flag is ``is_clickable``."""
    button_update = gr.Button(interactive=is_clickable)
    return button_update
def load_pil_img():
    """Open ``example_tmp/text/out_text_0.png`` and return it as a PIL image."""
    from PIL import Image

    result_path = "example_tmp/text/out_text_0.png"
    return Image.open(result_path)
def change_image(img):
    """Ignore ``img`` and return ``None``."""
    return
import shutil

# Remove the scratch folder left by a previous run before loading the
# segmentation code.
# NOTE(review): the import is kept after the cleanup as in the original;
# whether `segment` depends on that order is not visible here -- confirm.
_SCRATCH_DIR = "./example_tmp"
if os.path.isdir(_SCRATCH_DIR):
    shutil.rmtree(_SCRATCH_DIR)

from segment import run_segmentation
with gr.Blocks() as demo:
    # NOTE(review): the source lost its indentation, so the nesting of the
    # layout containers (rows / columns / accordions) below is a
    # reconstruction -- confirm against the intended page layout.

    # ---- State shared between callbacks ----
    image = gr.State() # store mask
    image_loaded = gr.State()
    segmentation = gr.State()
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    # ---- Step 1: upload + segmentation ----
    with gr.Row():
        with gr.Column():
            canvas = gr.Image(value=None, type="numpy", label="Show Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
            example_inps = [['./img.png'], ['./img2.png'], ['./img3.png'], ['./img4.png']]
            gr.Examples(examples=example_inps, inputs=[canvas],
                        label='examples', cache_examples='lazy', outputs=[],
                        fn=change_image)
            gr.Markdown("Each image must first undergo segmentation. Afterwards, you can modify the \n mask ID and the prompt for image editing, then proceed with the editing process. \n The link of D-edit paper: [https://arxiv.org/abs/2403.04880v2](https://arxiv.org/abs/2403.04880v2), [https://huggingface.co/papers/2403.04880](https://huggingface.co/papers/2403.04880)")
        with gr.Column():
            result_info0 = gr.Text(label="Response")
            segment_button = gr.Button("Step 1. Run segmentation")
            flag = gr.State(False)
            # The "updated" mask list is an alias of the segmentation output,
            # so callbacks wired to it read the same state as mask_np_list.
            mask_np_list_updated = mask_np_list
            gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Do not change it during the editing process)</p>""")
            slider = gr.Slider(0, 20, step=1, label='mask id', visible=False)
            label = gr.Text(label='label')
            result_info = gr.Text(label="Response")

            opt_flag = gr.State(0)
            gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings</p>""")
            with gr.Accordion(label="Advanced settings", open=False):
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive=True)
                # Same component under a second name; it is listed twice in
                # the click inputs below.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.00025", label="Embedding optimization: Learning rate", interactive=True)
                max_emb_train_steps = gr.Number(value="6", label="embedding optimization: Training steps", interactive=True)
                diffusion_model_learning_rate = gr.Textbox(value="0.0002", label="UNet Optimization: Learning rate", interactive=True)
                # Label fixed: this field holds the step count, not a learning rate.
                max_diffusion_train_steps = gr.Number(value="28", label="UNet Optimization: Training steps", interactive=True)
                train_batch_size = gr.Number(value="20", label="Batch size", interactive=True)
                gradient_accumulation_steps = gr.Number(value="2", label="Gradient accumulation", interactive=True)

    def run_optimization_wrapper(
        mask_np_list,
        mask_label_list,
        image,
        opt_flag,
        num_tokens,
        embedding_learning_rate,
        max_emb_train_steps,
        diffusion_model_learning_rate,
        max_diffusion_train_steps,
        train_batch_size,
        gradient_accumulation_steps,
    ):
        """Run the optimization stage of ``run_main`` and return a status string.

        UI values are converted to ``int``/``float`` before the call. The
        embedding steps are capped at 50 and the diffusion steps at 100.
        ``opt_flag`` is accepted but not used. Any exception is reported to
        the user and turned into an error message instead of propagating.
        """
        try:
            run_main(
                mask_np_list=mask_np_list,
                mask_label_list=mask_label_list,
                image_gt=np.array(image),
                num_tokens=int(num_tokens),
                embedding_learning_rate=float(embedding_learning_rate),
                max_emb_train_steps=min(int(max_emb_train_steps), 50),
                diffusion_model_learning_rate=float(diffusion_model_learning_rate),
                max_diffusion_train_steps=min(int(max_diffusion_train_steps), 100),
                train_batch_size=int(train_batch_size),
                gradient_accumulation_steps=int(gradient_accumulation_steps),
            )
        except Exception as e:
            print(e)
            # gr.Error is an exception class: building it without raising
            # shows nothing, and the old call passed the literal "e".
            # A warning toast shows the real message and still lets the
            # status string below be returned.
            gr.Warning(str(e))
            return "Error: use a smaller batch size or try latter."
        gr.Info("Optimization Finished! Move to the next step.")
        return "Optimization finished! Move to the next step."

    # ---- Step 2: editing ----
    with gr.Row():
        with gr.Column():
            canvas_text_edit = gr.Image(value=None, type="pil", label="Editing results", show_label=True, visible=True)
        with gr.Column():
            gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting</p>""")
            tgt_prompt = gr.Textbox(value="text prompt", label="Editing: Text prompt", interactive=True)
            with gr.Accordion(label="Advanced settings", open=False):
                slider2 = gr.Slider(0, 20, step=1, label='mask id', visible=False)
                guidance_scale = gr.Textbox(value="5", label="Editing: CFG guidance scale", interactive=True)
                num_sampling_steps = gr.Number(value="20", label="Editing: Sampling steps", interactive=True)
                edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive=True)
                strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive=True)
            add_button = gr.Button("Step 2. Run Editing", interactive=True)

    def run_edit_text_wrapper(
        mask_np_list,
        mask_label_list,
        image,
        num_tokens,
        guidance_scale,
        num_sampling_steps,
        strength,
        edge_thickness,
        tgt_prompt,
        tgt_index,
    ):
        """Run the text-editing stage of ``run_main`` and return the result image.

        UI values are converted to ``int``/``float`` before the call. The
        returned image is whatever ``load_pil_img`` reads back from disk.
        """
        run_main(
            mask_np_list=mask_np_list,
            mask_label_list=mask_label_list,
            image_gt=np.array(image),
            load_trained=True,
            text=True,
            # Use the value passed in by the event. The old code read
            # ``num_tokens_global.value``, the component's *initial* value,
            # so editing ignored the token count used for optimization.
            num_tokens=int(num_tokens),
            guidance_scale=float(guidance_scale),
            num_sampling_steps=int(num_sampling_steps),
            strength=float(strength),
            edge_thickness=int(edge_thickness),
            num_imgs=1,
            tgt_prompt=tgt_prompt,
            tgt_index=int(tgt_index),
        )
        gr.Info('Image editing completed.')
        return load_pil_img()

    def run_total_wrapper(mask_np_list, mask_label_list, image_loaded, opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps, diffusion_model_learning_rate, max_diffusion_train_steps, train_batch_size, gradient_accumulation_steps, num_tokens_global, guidance_scale, num_sampling_steps, strength, edge_thickness, tgt_prompt, slider2):
        """Run optimization followed by editing; return ``(status, edited_image)``."""
        # NOTE(review): editing still runs when optimization reported an
        # error, as in the original -- confirm whether it should be skipped.
        result_info = run_optimization_wrapper(mask_np_list, mask_label_list, image_loaded, opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps, diffusion_model_learning_rate, max_diffusion_train_steps, train_batch_size, gradient_accumulation_steps)
        canvas_text_edit = run_edit_text_wrapper(mask_np_list, mask_label_list, image_loaded, num_tokens_global, guidance_scale, num_sampling_steps, strength, edge_thickness, tgt_prompt, slider2)
        return result_info, canvas_text_edit

    # ---- Event wiring ----
    add_button.click(
        run_total_wrapper,
        inputs=[
            mask_np_list,
            mask_label_list,
            image_loaded,
            opt_flag,
            num_tokens,
            embedding_learning_rate,
            max_emb_train_steps,
            diffusion_model_learning_rate,
            max_diffusion_train_steps,
            train_batch_size,
            gradient_accumulation_steps,
            num_tokens_global,
            guidance_scale,
            num_sampling_steps,
            strength,
            edge_thickness,
            tgt_prompt,
            slider2,
        ],
        outputs=[result_info, canvas_text_edit],
    )

    # A fresh upload hides the mask-id slider until segmentation is rerun.
    canvas.upload(image_change, inputs=[], outputs=[slider])

    slider.release(slider_release,
                   inputs=[slider, image_loaded, mask_np_list_updated, mask_label_list],
                   outputs=[canvas, label, tgt_prompt])

    # Mirror the visible slider into the hidden one that feeds the edit step.
    slider.change(
        lambda x: x,
        inputs=[slider],
        outputs=[slider2]
    )

    segment_button.click(run_segmentation_wrapper,
                         [canvas],
                         [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2, result_info0])

demo.queue().launch(debug=True)