# ZIP-T/datasets/crowd.py
import os
from glob import glob
from typing import Optional, Callable, Union, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Normalize, Compose
from tqdm import tqdm
from turbojpeg import TurboJPEG, TJPF_RGB

from .utils import get_id, generate_density_map

# TurboJPEG is used instead of PIL for faster JPEG decoding.
jpeg_decoder = TurboJPEG()

curr_dir = os.path.dirname(os.path.abspath(__file__))
available_datasets = [
"shanghaitech_a", "sha",
"shanghaitech_b", "shb",
"shanghaitech", "sh",
"ucf_qnrf", "qnrf", "ucf-qnrf",
"nwpu", "nwpu_crowd", "nwpu-crowd",
]
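# OpenAI CLIP image normalization statistics (per-channel RGB mean/std).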
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
def standardize_dataset_name(dataset: str) -> str:
    """Map any accepted alias in `available_datasets` to its canonical short name."""
    dataset = dataset.lower()
    assert dataset in available_datasets, f"Dataset {dataset} is not available."
    if dataset in ["shanghaitech_a", "sha"]:
        return "sha"
    elif dataset in ["shanghaitech_b", "shb"]:
        return "shb"
    elif dataset in ["shanghaitech", "sh"]:
        return "sh"
    elif dataset in ["ucf_qnrf", "qnrf", "ucf-qnrf"]:
        return "qnrf"
    else:
        # Membership was already asserted above, so only NWPU aliases remain.
        return "nwpu"
class Crowd(Dataset):
def __init__(
self,
dataset: str,
split: str,
transforms: Optional[Callable] = None,
sigma: Optional[float] = None,
return_filename: bool = False,
num_crops: int = 1,
) -> None:
"""
Dataset for crowd counting.
"""
assert dataset.lower() in available_datasets, f"Dataset {dataset} is not available."
assert dataset.lower() not in ["shanghaitech", "sh"], "For the combined ShanghaiTech dataset, use ShanghaiTech class."
assert split in ["train", "val", "test"], f"Split {split} is not available."
assert num_crops > 0, f"num_crops should be positive, got {num_crops}."
self.dataset = standardize_dataset_name(dataset)
self.split = split
self.__find_root__()
self.__make_dataset__()
self.__check_sanity__()
self.to_tensor = ToTensor()
self.normalize = Normalize(mean=mean, std=std)
self.transforms = transforms
self.sigma = sigma
self.return_filename = return_filename
self.num_crops = num_crops
def __find_root__(self) -> None:
self.root = os.path.join(curr_dir, "..", "data", self.dataset)
def __make_dataset__(self) -> None:
image_names = glob(os.path.join(self.root, self.split, "images", "*.jpg"))
label_names = glob(os.path.join(self.root, self.split, "labels", "*.npy"))
image_names = [os.path.basename(image_name) for image_name in image_names]
label_names = [os.path.basename(label_name) for label_name in label_names]
image_names.sort(key=get_id)
label_names.sort(key=get_id)
image_ids = tuple([get_id(image_name) for image_name in image_names])
label_ids = tuple([get_id(label_name) for label_name in label_names])
assert image_ids == label_ids, "image_ids and label_ids do not match."
self.image_names = tuple(image_names)
self.label_names = tuple(label_names)
    def __check_sanity__(self) -> None:
        # Expected image counts per (dataset, split). Only train/val carry
        # labels here; the unlabeled NWPU test split is handled by NWPUTest.
        expected = {
            ("sha", "train"): 300, ("sha", "val"): 182,
            ("shb", "train"): 399, ("shb", "val"): 316,
            ("nwpu", "train"): 3109, ("nwpu", "val"): 500,
            ("qnrf", "train"): 1201, ("qnrf", "val"): 334,
        }
        key = (self.dataset, self.split)
        assert key in expected, f"Split {self.split} is not available for dataset {self.dataset}."
        count = expected[key]
        assert len(self.image_names) == len(self.label_names) == count, \
            f"{self.dataset} {self.split} split should have {count} images, but found {len(self.image_names)}."
def __len__(self) -> int:
return len(self.image_names)
def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
image_name = self.image_names[idx]
label_name = self.label_names[idx]
image_path = os.path.join(self.root, self.split, "images", image_name)
label_path = os.path.join(self.root, self.split, "labels", label_name)
        with open(image_path, "rb") as f:
            # Decode the JPEG directly to an RGB numpy array (faster than PIL).
            image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
image = self.to_tensor(image)
with open(label_path, "rb") as f:
label = np.load(f)
label = torch.from_numpy(label).float()
        if self.transforms is not None:
            # Each crop is an independently transformed (image, label) pair.
            images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
            images, labels = zip(*images_labels)
        else:
            images = [image.clone() for _ in range(self.num_crops)]
            labels = [label.clone() for _ in range(self.num_crops)]
        images = [self.normalize(img) for img in images]
        # One density map per crop. The labels are variable-length point sets,
        # so they are returned as a sequence rather than stacked into a tensor.
        density_maps = torch.stack([generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma) for image, label in zip(images, labels)], 0)
        image_names = [image_name] * len(images)
        images = torch.stack(images, 0)
        if self.return_filename:
            return images, labels, density_maps, image_names
        else:
            return images, labels, density_maps
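# Usage sketch (assumes the data layout expected by __find_root__, i.e.
# data/<dataset>/<split>/{images,labels}; the sigma value is illustrative):
#   train_set = Crowd(dataset="sha", split="train", sigma=4.0, num_crops=4)
#   images, labels, density_maps = train_set[0]
#   # images: (num_crops, 3, H, W) CLIP-normalized tensor
#   # labels: one point-annotation tensor per crop
#   # density_maps: (num_crops, ...) stacked density maps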
class InMemoryCrowd(Dataset):
def __init__(
self,
dataset: str,
split: str,
transforms: Optional[Callable] = None,
sigma: Optional[float] = None,
return_filename: bool = False,
num_crops: int = 1,
) -> None:
"""
Dataset for crowd counting, with images and labels loaded into memory.
"""
crowd = Crowd(
dataset=dataset,
split=split,
transforms=None,
sigma=sigma,
return_filename=True,
num_crops=1,
)
print(f"Loading {len(crowd)} samples from {dataset} {split} split into memory...")
self.images, self.labels, self.image_names = [], [], []
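        # Invert Normalize(mean, std): dividing by 1/std multiplies by std,
        # then subtracting -mean adds the mean back, recovering raw pixels.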
self.unnormalize = Compose([
Normalize(mean=(0., 0., 0.), std=(1./std[0], 1./std[1], 1./std[2]), inplace=True),
Normalize(mean=(-mean[0], -mean[1], -mean[2]), std=(1., 1., 1.), inplace=True)
])
for i in tqdm(range(len(crowd)), desc="Loading images and labels into memory"):
image, label, _, image_name = crowd[i]
self.images.append(self.unnormalize(image[0])) # recover original image
self.labels.append(label[0])
self.image_names.append(image_name[0])
assert len(self.images) == len(self.labels) == len(self.image_names), "Mismatch in number of images, labels, and image names."
self.transforms = transforms
self.sigma = sigma
self.num_crops = num_crops
self.return_filename = return_filename
self.normalize = Normalize(mean=mean, std=std, inplace=False)
def __len__(self) -> int:
return len(self.images)
def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
image, label, image_name = self.images[idx].clone(), self.labels[idx].clone(), self.image_names[idx]
if self.transforms is not None:
images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
images, labels = zip(*images_labels)
else:
images = [image.clone() for _ in range(self.num_crops)]
labels = [label.clone() for _ in range(self.num_crops)]
images = [self.normalize(img) for img in images]
density_maps = torch.stack([generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma) for image, label in zip(images, labels)], 0)
image_names = [image_name] * len(images)
images = torch.stack(images, 0)
if self.return_filename:
return images, labels, density_maps, image_names
else:
return images, labels, density_maps
class NWPUTest(Dataset):
def __init__(
self,
transforms: Optional[Callable] = None,
return_filename: bool = False,
) -> None:
"""
The test set of NWPU-Crowd dataset. The test set is not labeled, so only images are returned.
"""
self.root = os.path.join(curr_dir, "..", "data", "nwpu")
image_names = glob(os.path.join(self.root, "test", "images", "*.jpg"))
image_names = [os.path.basename(image_name) for image_name in image_names]
assert len(image_names) == 1500, f"NWPU test split should have 1500 images, but found {len(image_names)}."
image_names.sort(key=get_id)
self.image_names = tuple(image_names)
self.to_tensor = ToTensor()
self.normalize = Normalize(mean=mean, std=std)
self.transforms = transforms
self.return_filename = return_filename
def __len__(self) -> int:
return len(self.image_names)
def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, str]]:
image_name = self.image_names[idx]
image_path = os.path.join(self.root, "test", "images", image_name)
with open(image_path, "rb") as f:
# image = Image.open(f).convert("RGB")
image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
image = self.to_tensor(image)
        if self.transforms is not None:
            # Transforms expect an (image, label) pair; pass an empty dummy label.
            label = torch.tensor([], dtype=torch.float)
            image, _ = self.transforms(image, label)
image = self.normalize(image)
if self.return_filename:
return image, image_name
else:
return image
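# Usage sketch (illustrative; predictions would be written in whatever format
# the NWPU benchmark server expects):
#   test_set = NWPUTest(return_filename=True)
#   image, name = test_set[0]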
class ShanghaiTech(Dataset):
def __init__(
self,
split: str,
transforms: Optional[Callable] = None,
sigma: Optional[float] = None,
return_filename: bool = False,
num_crops: int = 1,
) -> None:
        """
        Concatenation of ShanghaiTech Part A and Part B. Indices below
        `len(self.sha)` map to Part A; the rest map to Part B.
        Arguments mirror those of `Crowd`.
        """
        super().__init__()
self.sha = Crowd(
dataset="sha",
split=split,
transforms=transforms,
sigma=sigma,
return_filename=return_filename,
num_crops=num_crops,
)
self.shb = Crowd(
dataset="shb",
split=split,
transforms=transforms,
sigma=sigma,
return_filename=return_filename,
num_crops=num_crops,
)
self.dataset = "sh"
self.split = split
self.transforms = transforms
self.sigma = sigma
self.return_filename = return_filename
self.num_crops = num_crops
def __len__(self) -> int:
return len(self.sha) + len(self.shb)
def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
if idx < len(self.sha):
return self.sha[idx]
else:
return self.shb[idx - len(self.sha)]
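# Hypothetical DataLoader setup: batch_size=1 sidesteps collating the
# variable-length point annotations; larger batches would need a custom
# collate_fn.
#   from torch.utils.data import DataLoader
#   loader = DataLoader(ShanghaiTech(split="train"), batch_size=1, shuffle=True)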