import os
import sys
from argparse import ArgumentParser, Namespace
from typing import Union, Tuple, List, Dict

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose

# Make the project's `datasets` package (one directory above this file) importable.
# NOTE(review): `append` puts parent_dir LAST on sys.path, so an installed package that is
# also named `datasets` would shadow the local one — confirm this cannot happen in your env.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets  # noqa: E402  (must come after the sys.path tweak above)


def calc_bin_center(
    bins: List[Tuple[float, float]],
    count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
    """
    Calculate the representative value for each bin based on the count statistics.

    `bins` may look like:
        [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
    `count_stats` may look like:
        {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}

    In this example, bin (2, 3) receives 30 samples of value 2 and 40 samples of value 3,
    so its representative value is (30 * 2 + 40 * 3) / (30 + 40) = 180 / 70 ≈ 2.571.

    Bin bounds are inclusive on both ends. Each value is assigned to the FIRST bin that
    contains it; values that fall into no bin are ignored.

    Args:
        bins: (start, end) pairs, inclusive on both ends.
        count_stats: mapping value -> number of samples with that value. Keys and values
            are coerced with `int()`, so e.g. string keys are accepted as well.

    Returns:
        (bin_centers, bin_counts), both with the same length as `bins`:
        bin_centers[i] is the count-weighted mean of the values that fell into bins[i];
        bin_counts[i] is the total number of samples that fell into bins[i].

    Raises:
        AssertionError: if any bin receives no samples (its center would be undefined).
    """
    bin_counts = [0] * len(bins)
    bin_sums = [0] * len(bins)
    for k, v in count_stats.items():
        value, count = int(k), int(v)  # normalise once per entry
        for i, (start, end) in enumerate(bins):
            if start <= value <= end:
                bin_counts[i] += count
                bin_sums[i] += count * value
                break  # first matching bin wins

    assert all(c > 0 for c in bin_counts), f"Expected all bin_counts to be greater than 0, got {bin_counts}. Consider to re-design the bins {bins}."

    bin_centers = [s / c for s, c in zip(bin_sums, bin_counts)]
    return bin_centers, bin_counts


def _build_transforms(args: Namespace, split: str):
    """Build the transform pipeline for `split` from the flags on `args`.

    - "train": random resized crop + horizontal flip, plus optional color jitter,
      Gaussian blur and pepper/salt noise, each enabled only when its strength > 0.
    - other splits: `Resize2Multiple` when both `sliding_window` and
      `resize_to_multiple` are set, otherwise None (no transform).
    """
    if split == "train":  # strong augmentation
        transforms = [
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
            datasets.RandomHorizontalFlip(),
        ]
        if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
            transforms.append(datasets.ColorJitter(
                brightness=args.aug_brightness,
                contrast=args.aug_contrast,
                saturation=args.aug_saturation,
                hue=args.aug_hue,
            ))
        if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
            transforms.append(datasets.RandomApply([
                datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
            ], p=args.aug_blur_prob))
        if args.aug_saltiness > 0 or args.aug_spiciness > 0:
            transforms.append(datasets.PepperSaltNoise(
                saltiness=args.aug_saltiness,
                spiciness=args.aug_spiciness,
            ))
        return Compose(transforms)

    if args.sliding_window and args.resize_to_multiple:
        return datasets.Resize2Multiple(args.window_size, stride=args.stride)
    return None


def get_dataloader(args: Namespace, split: str = "train") -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """Create the DataLoader for `split`.

    DDP mode is enabled when `args.nprocs > 1`.

    Args:
        args: parsed command-line arguments (an `argparse.Namespace`, not the parser).
        split: "train" selects the training pipeline. Any other value is treated as
            evaluation; it is only sharded across ranks when DDP is on AND split == "val".

    Returns:
        - split == "train": `(data_loader, sampler)`. `sampler` is the DistributedSampler
          under DDP and None otherwise. Under DDP, call `sampler.set_epoch(epoch)` every
          epoch so the shuffle order changes (DistributedSampler contract).
        - otherwise: `data_loader` alone, with batch size 1.
    """
    ddp = args.nprocs > 1
    is_train = split == "train"

    dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    dataset = dataset_class(
        dataset=args.dataset,
        split=split,
        transforms=_build_transforms(args, split),
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if is_train else 1,
    )

    # prefetch_factor / persistent_workers only make sense with worker processes.
    use_workers = args.num_workers > 0
    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
        prefetch_factor=3 if use_workers else None,
        persistent_workers=use_workers,
    )

    if is_train:
        if ddp:
            # The seed MUST be identical on every rank: DistributedSampler shuffles the full
            # index list with (seed + epoch) and then hands each rank a disjoint strided slice.
            # Per-rank seeds would produce different permutations, so shards would overlap
            # and some samples would be skipped each epoch.
            # NOTE(review): rank=local_rank is only the global rank on a single node — confirm.
            sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)
            data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, **loader_kwargs)
        else:
            sampler = None
            data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
        return data_loader, sampler

    # Evaluation: batch size 1, deterministic order.
    if ddp and split == "val":
        # NOTE: with the default drop_last=False, DistributedSampler pads the index list
        # (repeating samples) so it divides evenly across ranks; metrics aggregated over
        # ranks may therefore count a few samples twice.
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
        return DataLoader(dataset, batch_size=1, sampler=sampler, shuffle=False, **loader_kwargs)

    return DataLoader(dataset, batch_size=1, shuffle=False, **loader_kwargs)