# Script: compute DINOv2 (ViT-L/14 with registers) embeddings for YFCC images
# and package them, with the original JPEG bytes and metadata, into WebDataset
# tar shards.
import io
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from utils.image_processing import CenterCrop
| tqdm.pandas() | |
| print("Loading dinov2") | |
| augmentation_dinov2 = transforms.Compose( | |
| [ | |
| CenterCrop(ratio="1:1"), | |
| transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | |
| ] | |
| ) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg") | |
| model.eval() | |
| model.to(device) | |
| print(f"Model loaded on {device}") | |
class YFCCDataset(Dataset):
    """YFCC images plus geo metadata, yielding DINOv2-ready tensors and raw JPEG bytes.

    Expects a tab-separated CSV with at least ``hash``, ``photo_id``,
    ``latitude`` and ``longitude`` columns. Images are expected on disk at
    ``images_root/<hash[:3]>/<hash[3:6]>/<hash>.jpg``.
    """

    def __init__(self, csv_path, images_root):
        self.df = pd.read_csv(csv_path, sep="\t")
        # Keep only geotagged rows: both coordinates must be present.
        self.df = self.df[self.df["latitude"].notna() & self.df["longitude"].notna()]
        self.images_root = Path(images_root)
        # Derive the on-disk path for each image from its hash prefix.
        # Existence is checked lazily in __getitem__, not here.
        print("Building image paths...")
        self.df["image_path"] = self.df["hash"].progress_apply(
            lambda x: self.images_root / x[:3] / x[3:6] / f"{x}.jpg"
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with the preprocessed image tensor, raw JPEG bytes,
        photo id and metadata — or None when the image file is missing
        (None items are dropped by the collate function)."""
        row = self.df.iloc[idx]
        image_path = row["image_path"]
        if not image_path.exists():
            print(f"Image {image_path} does not exist")
            return None
        # Read the JPEG once and decode from the in-memory bytes, avoiding a
        # second read of the same file from disk.
        with open(image_path, "rb") as f:
            jpg_data = f.read()
        image = Image.open(io.BytesIO(jpg_data)).convert("RGB")
        image = augmentation_dinov2(image)
        # Convert metadata to plain-Python values so it is JSON serializable:
        # a pandas row holds numpy scalars, which json.dumps rejects.
        metadata = {
            k: (v.item() if hasattr(v, "item") else v)
            for k, v in row.to_dict().items()
        }
        del metadata["image_path"]
        return {
            "image": image,
            "jpg_data": jpg_data,
            "photo_id": str(row["photo_id"]),
            "metadata": metadata,
        }
def custom_collate(batch):
    """Collate dataset items (dicts or None) into a single batch dict.

    None items (missing image files) are dropped. Raises a clear ValueError
    when every item in the batch is None, instead of the opaque RuntimeError
    torch.stack would emit on an empty list.
    """
    # Filter once instead of once per field.
    valid = [item for item in batch if item is not None]
    if not valid:
        raise ValueError("custom_collate received a batch with no valid items")
    return {
        "image": torch.stack([item["image"] for item in valid]),
        "jpg_data": [item["jpg_data"] for item in valid],
        "photo_id": [item["photo_id"] for item in valid],
        "metadata": [item["metadata"] for item in valid],
    }
def process_batch(batch, model, device):
    """Embed a collated batch with `model` and build per-sample WebDataset records.

    Each record carries the raw JPEG, the embedding as an .npy payload, and
    the metadata dict as JSON, keyed by the photo id.
    """
    with torch.no_grad():
        # Images are already stacked by the collate function.
        embeddings = model(batch["image"].to(device)).cpu().numpy()
    return [
        {
            "__key__": key,
            "jpg": jpg,
            "dinov2_vitl14_registers.npy": emb,
            "json": meta,
        }
        for key, jpg, emb, meta in zip(
            batch["photo_id"], batch["jpg_data"], embeddings, batch["metadata"]
        )
    ]
def main(
    src_csv,
    src_images,
    dest_folder,
    num_samples_per_tar=10000,
    job_offset=0,
    batch_size=32,
):
    """Embed every image listed in src_csv and write WebDataset tar shards.

    Shard numbering starts at 10 * job_offset so parallel jobs writing into
    the same folder do not collide (assumes each job emits at most 10 shards).
    """
    print(f"Loading dataset")
    ds = YFCCDataset(src_csv, src_images)
    print(f"Processing job {job_offset} with {len(ds)} samples")
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        collate_fn=custom_collate,
    )
    shard_pattern = str(Path(dest_folder) / "%04d.tar")
    with wds.ShardWriter(
        shard_pattern,
        maxcount=num_samples_per_tar,
        start_shard=10 * job_offset,
    ) as sink:
        for batch in tqdm(loader):
            for sample in process_batch(batch, model, device):
                sink.write(sample)
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--src_csv_dir", help="pixel_input_folder") | |
| parser.add_argument("--src_images_dir", help="path to source images") | |
| parser.add_argument("--dest", help="path to destination web") | |
| parser.add_argument( | |
| "--num_samples_per_tar", | |
| help="number of samples per tar", | |
| type=int, | |
| default=10000, | |
| ) | |
| parser.add_argument("--job_offset", help="job offset", type=int, default=0) | |
| parser.add_argument("--batch_size", help="batch size", type=int, default=256) | |
| args = parser.parse_args() | |
| dest = Path(args.dest) | |
| dest.mkdir(exist_ok=True, parents=True) | |
| main( | |
| Path(args.src_csv_dir) / f"{str(args.job_offset).zfill(3)}.csv", | |
| args.src_images_dir, | |
| args.dest, | |
| args.num_samples_per_tar, | |
| args.job_offset, | |
| args.batch_size, | |
| ) | |