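"""Dataset classes for motion-blur decomposition: each sample pairs a blurred
input image with the sharp frame sequence it should be expanded into, plus the
normalized input/output time intervals consumed by the model.

Covers GoPro (1x and 2x), BAIST, a mixed "FullDataset", and in-the-wild photos.
"""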
import io
import os
import glob
import pickle
import random
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageCms
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from controlnet_aux import CannyDetector, HEDdetector

from helpers import generate_1x_sequence, generate_2x_sequence, generate_large_blur_sequence, generate_test_case
def unpack_mm_params(p):
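    """Normalize a size parameter to an (h, w) pair: tuples/lists pass through, scalars are duplicated."""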
    if isinstance(p, (tuple, list)):
        return p[0], p[1]
    elif isinstance(p, (int, float)):
        return p, p
    raise Exception(f'Unknown input parameter type.\nParameter: {p}.\nType: {type(p)}')
def resize_for_crop(image, min_h, min_w):
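    """Rescale `image` (..., H, W), preserving aspect ratio, ahead of a (min_h, min_w) center crop."""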
    img_h, img_w = image.shape[-2:]
    if img_h >= min_h and img_w >= min_w:
        coef = min(min_h / img_h, min_w / img_w)
    elif img_h <= min_h and img_w <= min_w:
        coef = max(min_h / img_h, min_w / img_w)
    else:
        coef = min_h / img_h if min_h > img_h else min_w / img_w
    out_h, out_w = int(img_h * coef), int(img_w * coef)
    resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
    return resized_image
class BaseClass(Dataset):
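    """Common plumbing for the motion-blur datasets: target-size handling,
    frame loading/normalization, and a retrying __getitem__ around get_batch.
    """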
    def __init__(
        self,
        data_dir,
        output_dir,
        image_size=(320, 512),
        hflip_p=0.5,
        controlnet_type='canny',
        split='train',
        *args,
        **kwargs
    ):
        self.split = split
        self.height, self.width = unpack_mm_params(image_size)
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.hflip_p = hflip_p
        self.image_size = image_size
        self.length = 0

    def __len__(self):
        return self.length
    def load_frames(self, frames):
        # frames: numpy array of shape (N, H, W, C), values in 0-255
        # -> float tensor of shape (N, C, H, W)
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().float()
        # normalize to [-1, 1]
        pixel_values = pixel_values / 127.5 - 1.0
        # resize to (self.height, self.width)
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.height, self.width),
            mode="bilinear",
            align_corners=False,
        )
        return pixel_values
    def get_batch(self, idx):
        raise NotImplementedError('get_batch must be implemented by subclasses.')
    def __getitem__(self, idx):
        # Retry with a random index if loading fails (e.g. a corrupt file)
        while True:
            try:
                video, caption, motion_blur = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)
        video = resize_for_crop(video, self.height, self.width)
        video = transforms.functional.center_crop(video, (self.height, self.width))
        data = {
            'video': video,
            'caption': caption,
        }
        return data
def load_as_srgb(path):
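    """Open an image, apply EXIF orientation, and convert it to sRGB (via its ICC profile when present)."""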
    img = Image.open(path)
    img = ImageOps.exif_transpose(img)
    if 'icc_profile' in img.info:
        icc = img.info['icc_profile']
        src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
        dst_profile = ImageCms.createProfile("sRGB")
        img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
    else:
        img = img.convert("RGB")  # no embedded profile: assume the image is already sRGB
    return img
class GoProMotionBlurDataset(BaseClass):
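    """7-frame GoPro dataset: one blurred frame paired with the 7-frame sharp
    window centered on it (the last frame is repeated `pad` times, for 9 outputs).
    """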
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")
        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')
        self.blur_paths = sorted(glob.glob(pattern))
        if self.split == 'val':
            # Optional: limit validation subset
            self.blur_paths = self.blur_paths[:5]
        # Skip samples whose deblurred output already exists on disk
        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)
        self.blur_paths = filtered_blur_paths
        # Window and padding parameters
        self.window_size = 7  # original number of sharp frames
        self.pad = 2  # number of times to repeat the last frame
        self.output_length = self.window_size + self.pad
        self.half_window = self.window_size // 2
        self.length = len(self.blur_paths)
        # Normalized input interval: always [-0.5, 0.5]
        self.input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
        # Precompute normalized output intervals: first for window_size frames, then pad duplicates
        step = 1.0 / (self.window_size - 1)
        # intervals for the original 7 frames
        window_intervals = []
        for i in range(self.window_size):
            start = -0.5 + i * step
            if i < self.window_size - 1:
                end = -0.5 + (i + 1) * step
            else:
                end = 0.5
            window_intervals.append([start, end])
        # append the last interval pad times
        intervals = window_intervals + [window_intervals[-1]] * self.pad
        self.output_interval = torch.tensor(intervals, dtype=torch.float)
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        # Path to the blurred (center) frame
        blur_path = self.blur_paths[idx]
        seq_name = os.path.basename(os.path.dirname(blur_path))
        frame_name = os.path.basename(blur_path)
        center_idx = int(os.path.splitext(frame_name)[0])
        # Compute sharp frame range [center - half, center + half]
        start_idx = center_idx - self.half_window
        end_idx = center_idx + self.half_window
        # Load sharp frames
        sharp_dir = os.path.join(self.sharp_root, seq_name)
        frames = []
        for i in range(start_idx, end_idx + 1):
            sharp_filename = f"{i:06d}.png"
            sharp_path = os.path.join(sharp_dir, sharp_filename)
            img = Image.open(sharp_path).convert('RGB')
            frames.append(img)
        # Repeat last sharp frame so total frames == output_length
        while len(frames) < self.output_length:
            frames.append(frames[-1])
        # Load blurred image
        blur_img = Image.open(blur_path).convert('RGB')
        # Convert to pixel values via the BaseClass loader
        video = self.load_frames(np.array(frames))  # shape: (output_length, C, H, W)
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))  # shape: (1, C, H, W)
        data = {
            'file_name': os.path.join(seq_name, frame_name),
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'motion_blur_amount': torch.tensor(self.half_window, dtype=torch.long),
            'input_interval': self.input_interval,
            'output_interval': self.output_interval,
            'num_frames': self.window_size,
            'mode': "1x",
        }
        return data
class OutsidePhotosDataset(BaseClass):
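    """In-the-wild blurred photos with no ground-truth sharp frames: each image
    is paired with dummy target frames and the precomputed 1x/2x intervals below.
    """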
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image_paths = sorted(glob.glob(os.path.join(self.data_dir, '**', '*.*'), recursive=True))
        INTERVALS = [
            {"in_start": 0, "in_end": 16, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "1x", "fps": 240},
            {"in_start": 4, "in_end": 12, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "2x", "fps": 240},
            # other modes commented out for faster processing:
            # {"in_start": 0, "in_end": 4, "out_start": 0, "out_end": 4, "center": 2, "window_size": 4, "mode": "1x", "fps": 240},
            # {"in_start": 0, "in_end": 8, "out_start": 0, "out_end": 8, "center": 4, "window_size": 8, "mode": "1x", "fps": 240},
            # {"in_start": 0, "in_end": 12, "out_start": 0, "out_end": 12, "center": 6, "window_size": 12, "mode": "1x", "fps": 240},
            # {"in_start": 0, "in_end": 32, "out_start": 0, "out_end": 32, "center": 12, "window_size": 32, "mode": "lb", "fps": 120},
            # {"in_start": 0, "in_end": 48, "out_start": 0, "out_end": 48, "center": 24, "window_size": 48, "mode": "lb", "fps": 80},
        ]
        self.cleaned_intervals = []
        for image_path in self.image_paths:
            for interval in INTERVALS:
                # copy the interval dict and attach the image path
                i = interval.copy()
                i['video_name'] = image_path
                video_name = i['video_name']
                mode = i['mode']
                vid_name_no_ext = os.path.relpath(video_name, self.data_dir).split('.')[0]  # e.g. "frame_00000"
                output_name = f"{vid_name_no_ext}_{mode}.mp4"
                # HACK: hard-coded output directory -- update this to your own output path
                full_output_path = os.path.join("/datasets/sai/gencam/cogvideox/training/cogvideox-outsidephotos/deblurred", output_name)
                # Keep only samples whose output doesn't exist yet
                if not os.path.exists(full_output_path):
                    self.cleaned_intervals.append(i)
        self.length = len(self.cleaned_intervals)
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        interval = self.cleaned_intervals[idx]
        in_start = interval['in_start']
        in_end = interval['in_end']
        out_start = interval['out_start']
        out_end = interval['out_end']
        center = interval['center']
        window = interval['window_size']
        mode = interval['mode']
        fps = interval['fps']
        image_path = interval['video_name']
        blur_img_original = load_as_srgb(image_path)
        W, H = blur_img_original.size  # PIL .size is (width, height)
        # No ground truth exists for these photos, so feed dummy frames to the generator
        frame_paths = ["../assets/dummy_image.png" for _ in range(window)]  # any valid path, replicated
        # generate test case
        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end,
            out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
        )
        file_name = image_path
        # Get base directory and frame prefix
        relative_file_name = os.path.relpath(file_name, self.data_dir)
        base_dir = os.path.dirname(relative_file_name)
        frame_stem = os.path.splitext(os.path.basename(file_name))[0]  # e.g. "frame_00000"
        # Build new filename
        new_filename = f"{frame_stem}_{mode}.png"
        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]))  # PIL resize takes (width, height)
        # Final path
        relative_file_name = os.path.join(base_dir, new_filename)
        blur_input = self.load_frames(np.expand_dims(blur_img, 0).copy())
        # seq_frames is a list of frames; stack along the time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))
        data = {
            'file_name': relative_file_name,
            'original_size': (H, W),
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            'num_frames': num_frames,
        }
        return data
class FullMotionBlurDataset(BaseClass):
    """
    A dataset that randomly selects among 1x, 2x, or large-blur modes per sample.
    Uses category-specific <split>_list.txt files under each subfolder of FullDataset to assemble sequences.
    In the 'test' split, it instead loads precomputed intervals from intervals_test.pkl and uses generate_test_case.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.seq_dirs = []
        # TEST split: load fixed intervals early
        if self.split == 'test':
            pkl_path = os.path.join(self.data_dir, 'intervals_test.pkl')
            with open(pkl_path, 'rb') as f:
                self.test_intervals = pickle.load(f)
            assert self.test_intervals, f"No test intervals found in {pkl_path}"
            cleaned_intervals = []
            for interval in self.test_intervals:
                # Extract interval values
                in_start = interval['in_start']
                in_end = interval['in_end']
                out_start = interval['out_start']
                out_end = interval['out_end']
                center = interval['center']
                window = interval['window_size']
                mode = interval['mode']
                fps = interval['fps']
                # video_name is "<category>/<sequence>"
                category, seq = interval['video_name'].split('/')
                seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
                frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))  # e.g. "lower_fps_frames/720p_240fps_1/frame_00247.png"
                rel_path = os.path.relpath(frame_paths[center], self.data_dir)
                rel_path = os.path.splitext(rel_path)[0]  # remove the file extension
                output_name = (
                    f"{rel_path}_"
                    f"in{in_start:04d}_ie{in_end:04d}_"
                    f"os{out_start:04d}_oe{out_end:04d}_"
                    f"ctr{center:04d}_win{window:04d}_"
                    f"fps{fps:04d}_{mode}.mp4"
                )
                output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                full_output_path = os.path.join(output_deblurred_dir, output_name)
                # Keep only if the output doesn't exist yet
                if not os.path.exists(full_output_path):
                    cleaned_intervals.append(interval)
            print("Len of test intervals before cleaning:", len(self.test_intervals))
            print("Len of test intervals after cleaning:", len(cleaned_intervals))
            self.test_intervals = cleaned_intervals
        # TRAIN/VAL: build seq_dirs from each category's list file, or fall back
        # to scanning lower_fps_frames directly
        list_file = 'train_list.txt' if self.split == 'train' else 'test_list.txt'
        for category in sorted(os.listdir(self.data_dir)):
            cat_dir = os.path.join(self.data_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            list_path = os.path.join(cat_dir, list_file)
            if os.path.isfile(list_path):
                with open(list_path, 'r') as f:
                    for line in f:
                        rel = line.strip()
                        if not rel:
                            continue
                        seq_dir = os.path.join(self.data_dir, rel)
                        if os.path.isdir(seq_dir):
                            self.seq_dirs.append(seq_dir)
            else:
                fps_root = os.path.join(cat_dir, 'lower_fps_frames')
                if os.path.isdir(fps_root):
                    for seq in sorted(os.listdir(fps_root)):
                        seq_path = os.path.join(fps_root, seq)
                        if os.path.isdir(seq_path):
                            self.seq_dirs.append(seq_path)
        if self.split == 'val':
            self.seq_dirs = self.seq_dirs[:5]
        if self.split == 'train':
            self.seq_dirs *= 10  # repeat the sequence list to lengthen an epoch
        assert self.seq_dirs, \
            f"No sequences found for split '{self.split}' in {self.data_dir}"
    def __len__(self):
        return len(self.test_intervals) if self.split == 'test' else len(self.seq_dirs)
    def __getitem__(self, idx):
        # Prepare base items
        if self.split == 'test':
            interval = self.test_intervals[idx]
            category, seq = interval['video_name'].split('/')
            seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            in_start = interval['in_start']
            in_end = interval['in_end']
            out_start = interval['out_start']
            out_end = interval['out_end']
            center = interval['center']
            window = interval['window_size']
            mode = interval['mode']
            fps = interval['fps']
            # generate test case
            blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
                frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end,
                out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
            )
            file_name = frame_paths[center]
        else:
            seq_dir = self.seq_dirs[idx]
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            mode = random.choice(['1x', '2x', 'large_blur'])
            # Fall back to 1x when the sequence is too short for the other modes
            if mode == '1x' or len(frame_paths) < 50:
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_1x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            elif mode == '2x':
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_2x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            else:
                max_base = min((len(frame_paths) - 1) // 17, 3)
                base_rate = random.randint(1, max_base)
                blur_img, seq_frames, inp_int, out_int, _ = generate_large_blur_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            file_name = frame_paths[0]
            num_frames = 16
        # blur_img is a single frame; wrap it in a batch dim
        blur_input = self.load_frames(np.expand_dims(blur_img, 0))
        # seq_frames is a list of frames; stack along the time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))
        relative_file_name = os.path.relpath(file_name, self.data_dir)
        if self.split == 'test':
            # Get base directory and frame prefix
            base_dir = os.path.dirname(relative_file_name)
            frame_stem = os.path.splitext(os.path.basename(relative_file_name))[0]  # e.g. "frame_00000"
            # Build new filename encoding the interval parameters
            new_filename = (
                f"{frame_stem}_"
                f"in{in_start:04d}_ie{in_end:04d}_"
                f"os{out_start:04d}_oe{out_end:04d}_"
                f"ctr{center:04d}_win{window:04d}_"
                f"fps{fps:04d}_{mode}.png"
            )
            # Final path
            relative_file_name = os.path.join(base_dir, new_filename)
        data = {
            'file_name': relative_file_name,
            'blur_img': blur_input,
            'num_frames': num_frames,
            'video': video,
            'caption': "",
            'mode': mode,
            'input_interval': inp_int,
            'output_interval': out_int,
        }
        if self.split == 'test':
            high_fps_video = self.load_frames(np.stack(high_fps_video, axis=0))
            data['high_fps_video'] = high_fps_video
        return data
class GoPro2xMotionBlurDataset(BaseClass):
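    """GoPro evaluation in "2x" mode: the input interval covers only the inner
    frames of a 13-frame sharp window, while the output interval spans all of it.
    """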
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")
        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')
        def get_sharp_paths(blur_paths):
            # For each blurred frame, build the 13-frame sharp window (offsets -6..6)
            # and keep it only if every frame exists on disk
            sharp_paths = []
            for blur_path in blur_paths:
                base_dir = blur_path.replace('/blur/', '/sharp/')
                frame_num = int(os.path.basename(blur_path).split('.')[0])
                dir_path = os.path.dirname(base_dir)
                sequence = [
                    os.path.join(dir_path, f"{frame_num + offset:06d}.png")
                    for offset in range(-6, 7)
                ]
                if all(os.path.exists(path) for path in sequence):
                    sharp_paths.append(sequence)
            return sharp_paths
        self.blur_paths = sorted(glob.glob(pattern))
        # Skip samples whose deblurred output already exists on disk
        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)
        self.blur_paths = filtered_blur_paths
        self.sharp_paths = get_sharp_paths(self.blur_paths)
        if self.split == 'val':
            # Optional: limit validation subset
            self.sharp_paths = self.sharp_paths[:5]
        self.length = len(self.sharp_paths)
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        # 13-frame sharp window centered on the blurred frame
        sharp_path = self.sharp_paths[idx]
        # Build the 2x test case from the sharp frames
        blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=sharp_path, window_max=13, in_start=3, in_end=10, out_start=0,
            out_end=13, center=6, mode="2x", fps=240
        )
        # Convert to pixel values via the BaseClass loader
        video = self.load_frames(np.array(seq_frames))  # shape: (num_frames, C, H, W)
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))  # shape: (1, C, H, W)
        # "<sequence>/<frame>.png" of the center frame
        last_two_parts_of_path = os.path.join(*sharp_path[6].split(os.sep)[-2:])
        data = {
            'file_name': last_two_parts_of_path,
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            'num_frames': num_frames,
            'mode': "2x",
        }
        return data
class BAISTDataset(BaseClass):
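    """BAIST: each blurred frame is paired with its 7 sharp ground-truth frames
    and a subject bounding box loaded from the matching blur_anno pickle.
    """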
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sequences reserved for the test/val splits
        test_folders = {
            "gWA_sBM_c01_d26_mWA0_ch06_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch01_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch04_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch05_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch08_cropped_32X",
            "gWA_sBM_c01_d26_mWA0_ch02_cropped_32X",
            "gJS_sBM_c01_d02_mJS0_ch08_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch07_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch06_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch03_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch05_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch02_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch03_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch09_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch10_cropped_32X",
            "gWA_sBM_c01_d26_mWA0_ch10_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch06_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch08_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch06_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch10_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch09_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch02_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch04_cropped_32X",
            "gPO_sBM_c01_d10_mPO0_ch09_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch01_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch07_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch03_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch04_cropped_32X",
            "gBR_sBM_c01_d05_mBR0_ch02_cropped_32X",
            "gHO_sBM_c01_d20_mHO0_ch01_cropped_32X",
            "gMH_sBM_c01_d22_mMH0_ch05_cropped_32X",
            "gPO_sBM_c01_d10_mPO0_ch10_cropped_32X",
        }
        def collect_blur_images(root_dir, allowed_folders, skip_start=40, skip_end=40):
            blur_image_paths = []
            for dirpath, dirnames, filenames in os.walk(root_dir):
                if os.path.basename(dirpath) == "blur":
                    parent_folder = os.path.basename(os.path.dirname(dirpath))
                    if (self.split in ["test", "val"] and parent_folder in allowed_folders) or \
                            (self.split == "train" and parent_folder not in allowed_folders):
                        # Filter and sort valid image filenames
                        valid_files = [
                            f for f in filenames
                            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')) and os.path.splitext(f)[0].isdigit()
                        ]
                        valid_files.sort(key=lambda x: int(os.path.splitext(x)[0]))
                        # Skip the first and last N files
                        middle_files = valid_files[skip_start:len(valid_files) - skip_end]
                        for f in middle_files:
                            full_path = Path(os.path.join(dirpath, f))
                            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                            full_output_path = Path(output_deblurred_dir, *full_path.parts[-3:]).with_suffix(".mp4")
                            if not os.path.exists(full_output_path) or self.split in ["train", "val"]:
                                blur_image_paths.append(os.path.join(dirpath, f))
            return blur_image_paths
        self.image_paths = collect_blur_images(self.data_dir, test_folders)
        # Drop images whose bounding-box annotation is missing
        self.image_paths = [path for path in self.image_paths if os.path.exists(path.replace("blur", "blur_anno").replace(".png", ".pkl"))]
        # Drop images that don't have all 7 sharp ground-truth frames
        filtered_image_paths = []
        for blur_path in self.image_paths:
            base_dir = blur_path.replace('/blur/', '/sharp/').replace('.png', '')
            sharp_paths = [f"{base_dir}_{i:03d}.png" for i in range(7)]
            if all(os.path.exists(p) for p in sharp_paths):
                filtered_image_paths.append(blur_path)
        self.image_paths = filtered_image_paths
        if self.split == 'val':
            # Optional: limit validation subset
            self.image_paths = self.image_paths[:4]
        self.length = len(self.image_paths)
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        blur_img_original = load_as_srgb(image_path)
        # Load the bounding-box annotation (x0, y0, x1, y1)
        bbx_path = image_path.replace("blur", "blur_anno").replace(".png", ".pkl")
        bbx = np.load(bbx_path, allow_pickle=True)['bbox'][0:4]
        W, H = blur_img_original.size  # PIL .size is (width, height)
        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]), resample=Image.BILINEAR)  # PIL resize takes (width, height)
        blur_np = np.array([blur_img])
        base_dir = os.path.dirname(os.path.dirname(image_path))  # strip /blur
        filename = os.path.splitext(os.path.basename(image_path))[0]  # e.g. '00000000'
        sharp_dir = os.path.join(base_dir, "sharp")
        frame_paths = [
            os.path.join(sharp_dir, f"{filename}_{i:03d}.png")
            for i in range(7)
        ]
        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=7, in_start=0, in_end=7, out_start=0,
            out_end=7, center=3, mode="1x", fps=240
        )
        pixel_values = self.load_frames(np.stack(seq_frames, axis=0))
        blur_pixel_values = self.load_frames(blur_np)
        relative_file_name = os.path.relpath(image_path, self.data_dir)
        # Scale the bounding box from the original resolution to the model resolution
        out_bbx = bbx.copy()
        scale_x = blur_pixel_values.shape[3] / W
        scale_y = blur_pixel_values.shape[2] / H
        out_bbx[0] = int(out_bbx[0] * scale_x)
        out_bbx[1] = int(out_bbx[1] * scale_y)
        out_bbx[2] = int(out_bbx[2] * scale_x)
        out_bbx[3] = int(out_bbx[3] * scale_y)
        out_bbx = torch.tensor(out_bbx, dtype=torch.uint32)
        data = {
            'file_name': relative_file_name,
            'blur_img': blur_pixel_values,
            'video': pixel_values,
            'bbx': out_bbx,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            'num_frames': num_frames,
            'mode': "1x",
        }
        return data
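
# Minimal smoke-test sketch (not part of the training pipeline): the data_dir /
# output_dir values below are placeholders -- point them at your own copies of
# the datasets before running.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = GoProMotionBlurDataset(
        data_dir='/path/to/GoPro',      # placeholder path
        output_dir='/path/to/outputs',  # placeholder path
        image_size=(320, 512),
        split='val',
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    sample = next(iter(loader))
    # blur_img: (B, 1, C, H, W); video: (B, output_length, C, H, W)
    print(sample['file_name'], sample['blur_img'].shape, sample['video'].shape)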