import gradio as gr
import os
import warnings

so_path = "models/GroundingDINO/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so"
if not os.path.exists(so_path):
    os.system("python models/GroundingDINO/ops/setup.py build_ext develop --user")

import torchvision.transforms as T
from models import build_model
import torch
import misc as utils
import numpy as np
import torch.nn.functional as F
from torchvision.io import read_video
import torchvision.transforms.functional as Func
from ruamel.yaml import YAML
from easydict import EasyDict
from misc import nested_tensor_from_videos_list
from torch.cuda.amp import autocast
from PIL import Image, ImageDraw
import imageio.v3 as iio
import cv2
import tempfile
import argparse
import time
from huggingface_hub import hf_hub_download

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DURATION = 6
CHECKPOINT = "ryt_mevis_swinb.pth"

# Transform for video frames
transform = T.Compose([
    T.Resize(360),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Colormap
color_list = utils.colormap()
color_list = color_list.astype('uint8').tolist()

# Global model variable
model = None


def load_model_once(config_path, device='cpu'):
    """Load model once at startup"""
    global model
    if model is None:
        # Create args object for model loading
        with open(config_path) as f:
            yaml = YAML(typ='safe', pure=True)
            config = yaml.load(f)
        config = {k: v['value'] for k, v in config.items()}
        args = EasyDict(config)
        args.device = device

        model = build_model(args)
        model.to(device)

        cache_file = hf_hub_download(repo_id="liangtm/referdino", filename=CHECKPOINT)
        # cache_file = 'ckpt/' + CHECKPOINT
        checkpoint = torch.load(cache_file, map_location='cpu')
        state_dict = checkpoint["model_state_dict"]
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        print("Model loaded successfully!")
    return model


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    b = np.stack([
        x_c - 0.5 * w, y_c - 0.5 * h,
        x_c + 0.5 * w, y_c + 0.5 * h
    ], axis=1)
    return b


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b


def vis_add_mask(img, mask, color, edge_width=3):
    origin_img = np.asarray(img.convert('RGB')).copy()
    color = np.array(color)
    mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8')
    mask = mask > 0.5

    # Increase the edge width using dilation
    kernel = np.ones((edge_width, edge_width), np.uint8)
    mask_dilated = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
    edge_mask = mask_dilated & ~mask

    origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
    origin_img[edge_mask] = color
    origin_img = Image.fromarray(origin_img)
    return origin_img


def run_video_inference(input_video, text_prompt, tracking_alpha=0.1, fps=15):
    """Main inference function for Gradio"""
    global model
    model.tracking_alpha = tracking_alpha

    # Set default values for other parameters
    show_box = True
    mask_edge_width = 6

    if input_video is None:
        return None, "Please upload a video file."

    if not text_prompt or text_prompt.strip() == "":
        return None, "Please enter a text prompt."
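
    # Overview of the steps below: normalize the text prompt, subsample frames to
    # the requested FPS, run the model once over the whole clip, keep the
    # highest-scoring query for the video, then overlay masks/boxes and encode
    # the annotated frames into an MP4.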
    # Process text prompt
    exp = " ".join(text_prompt.lower().split())

    # Read video
    video_frames, _, info = read_video(input_video, end_pts=DURATION, pts_unit='sec')  # (T, H, W, C)
    frame_step = max(round(info['video_fps'] / fps), 1)

    frames = []
    for i in range(0, len(video_frames), frame_step):
        source_frame = Func.to_pil_image(video_frames[i].permute(2, 0, 1))
        frames.append(source_frame)

    video_len = len(frames)
    if video_len == 0:
        return None, "No frames found in the video."

    frames_ids = [x for x in range(video_len)]
    imgs = []
    for t in frames_ids:
        img = frames[t]
        origin_w, origin_h = img.size
        imgs.append(transform(img))

    device = next(model.parameters()).device
    imgs = torch.stack(imgs, dim=0).to(device)
    samples = nested_tensor_from_videos_list(imgs[None], size_divisibility=16)
    img_h, img_w = imgs.shape[-2:]
    size = torch.as_tensor([int(img_h), int(img_w)]).to(device)
    target = {"size": size}

    start_infer = time.time()
    # Run inference
    with torch.no_grad():
        with autocast(True):
            outputs = model(samples, [exp], [target])
    end_infer = time.time()

    pred_logits = outputs["pred_logits"][0]  # [t, q, k]
    pred_masks = outputs["pred_masks"][0]    # [t, q, h, w]
    pred_boxes = outputs["pred_boxes"][0]    # [t, q, 4]

    # Select the query index according to pred_logits
    pred_scores = pred_logits.sigmoid()  # [t, q, k]
    pred_scores = pred_scores.mean(0)    # [q, k]
    max_scores, _ = pred_scores.max(-1)  # [q,]
    _, max_ind = max_scores.max(-1)      # [1,]
    max_inds = max_ind.repeat(video_len)
    pred_masks = pred_masks[range(video_len), max_inds, ...]  # [t, h, w]
    pred_masks = pred_masks.unsqueeze(0)
    pred_boxes = pred_boxes[range(video_len), max_inds].cpu().numpy()  # [t, 4]

    # Unpad and resize
    pred_masks = pred_masks[:, :, :img_h, :img_w].cpu()
    pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
    pred_masks = (pred_masks.sigmoid() > 0.5).squeeze(0).cpu().numpy()

    # Visualization
    color = np.array([220, 20, 60], dtype=np.uint8)

    start_save = time.time()
    save_imgs = []
    for t, img in enumerate(frames):
        # Draw mask
        img = vis_add_mask(img, pred_masks[t], color, mask_edge_width)

        draw = ImageDraw.Draw(img)
        draw_boxes = pred_boxes[t][None]
        draw_boxes = rescale_bboxes(draw_boxes, (origin_w, origin_h)).tolist()

        # Draw box if enabled
        if show_box:
            xmin, ymin, xmax, ymax = draw_boxes[0]
            draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=tuple(color), width=5)

        save_imgs.append(np.asarray(img).copy())

    # Save result video
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        iio.imwrite(tmp_file.name, save_imgs, fps=fps)
        result_video_path = tmp_file.name
    end_save = time.time()

    status = (
        f"Inference Time: {(end_infer - start_infer):.1f}s\n"
        f"Saving Time: {(end_save - start_save):.1f}s"
    )
    return result_video_path, status


def main():
    # Configuration
    config_path = "configs/ytvos_swinb.yaml"  # Update this path
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"

    # Load model at startup
    print("Loading model...")
    load_model_once(config_path, device)
    print(f"Model loaded on device: {device}")

    # Create Gradio interface
    with gr.Blocks(
        title="ReferDINO",
        css="""
        #hero { text-align: center; }
        #hero h1, #hero h2, #hero h3, #hero p {
            text-align: center !important;
            margin: 0.25rem 0;
        }
        """
    ) as demo:
        gr.Markdown(
            """

# Referring Video Object Segmentation with ReferDINO

Note that this demo runs on CPU, so the video will be trimmed to ≤6 seconds.

            """,
            elem_id="hero",
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                input_video = gr.Video(
                    label="📹 Upload Video",
                    height=300
                )
                text_prompt = gr.Textbox(
                    label="📝 Text Description",
                    placeholder="Describe the object you want to segment (e.g., 'red car', 'person in blue shirt')",
                    lines=2
                )
                run_button = gr.Button(
                    "🚀 Run Inference",
                    variant="primary",
                    size="lg"
                )
                tracking_alpha = gr.Slider(
                    label="Momentum",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    info="Controls the memory update (lower = longer memory)"
                )
                target_fps = gr.Slider(
                    label="FPS",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    info="Controls the sampling FPS (lower = faster processing)"
                )

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="🎯 Segmentation Result",
                    height=400
                )
                status_text = gr.Textbox(
                    label="📊 Status",
                    lines=3,
                    interactive=False
                )

        # Examples
        gr.Examples(
            examples=[
                ["dogs.mp4", "the dog is drinking water", 0.1, 10],
                ["dogs.mp4", "the dog is sleeping", 0.1, 10],
            ],
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            outputs=[output_video, status_text],
            fn=run_video_inference,
            cache_examples=False,
            label="📋 Try these examples:"
        )

        # Event handlers
        run_button.click(
            fn=run_video_inference,
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            outputs=[output_video, status_text],
            show_progress=True
        )

    return demo


if __name__ == "__main__":
    demo = main()
    demo.launch(
        show_api=False,
        show_error=True
    )