from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, Tensor
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import sliding_window_predict, barrier, calculate_errors


def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    sliding_window: bool,
    max_input_size: int = 4096,
    window_size: int = 224,
    stride: int = 224,
    max_num_windows: int = 64,
    device: torch.device = torch.device("cuda"),
    amp: bool = False,
    local_rank: int = 0,
    nprocs: int = 1,
    progress_bar: bool = True,
) -> Tuple[Tensor, Tensor]:
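    """Evaluate a density-map counting model and return its count errors.

    Images with a side shorter than `window_size` are bicubically upsampled
    before inference; any input whose pixel count exceeds `max_input_size ** 2`
    (or every input, when `sliding_window` is set) is predicted window by
    window via `sliding_window_predict`. Under DDP (`nprocs` > 1), per-rank
    counts are NaN-padded to a common length, all-gathered, and un-padded
    before the errors are computed by `calculate_errors`.
    """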
    ddp = nprocs > 1
    model = model.to(device)
    model.eval()

    pred_counts, gt_counts = [], []
    # Show the progress bar on rank 0 only to avoid interleaved output under DDP.
    data_iter = tqdm(data_loader) if (local_rank == 0 and progress_bar) else data_loader

    for image, gt_points, _ in data_iter:
        image = image.to(device)
        image_height, image_width = image.shape[-2:]
        # The ground-truth count is the number of annotated points per image.
        gt_counts.extend([len(p) for p in gt_points])

        # Upsample images smaller than the window along either side, preserving
        # the aspect ratio, so that at least one full window fits the image.
        aspect_ratio = image_width / image_height
        if image_height < window_size:
            new_height = window_size
            new_width = int(new_height * aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width
        if image_width < window_size:
            new_width = window_size
            new_height = int(new_width / aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width

        with torch.no_grad(), autocast(device_type=device.type, enabled=amp):
            # Fall back to sliding-window inference for very large inputs to
            # bound the peak memory of a single forward pass.
            if sliding_window or (image_height * image_width) > max_input_size ** 2:
                pred_den_maps = sliding_window_predict(model, image, window_size, stride, max_num_windows)
            else:
                pred_den_maps = model(image)

        # The predicted count is the integral of the density map over C, H, W.
        pred_counts.extend(pred_den_maps.sum(dim=(-1, -2, -3)).cpu().numpy().tolist())

    # Wait for all ranks to finish inference before gathering results.
    barrier(ddp)
    assert len(pred_counts) == len(gt_counts), (
        f"Expected the same number of predictions and ground truths, "
        f"got {len(pred_counts)} and {len(gt_counts)}."
    )

    if ddp:
        pred_counts = torch.tensor(pred_counts, device=device)
        gt_counts = torch.tensor(gt_counts, device=device)

        # Ranks may hold different numbers of samples, so pad each rank's
        # counts with NaN up to the maximum length before all_gather, then
        # strip the padding afterwards.
        local_length = torch.tensor([len(pred_counts)], device=device)
        lengths = [torch.zeros_like(local_length) for _ in range(nprocs)]
        dist.all_gather(lengths, local_length)
        max_length = max(l.item() for l in lengths)

        padded_pred_counts = torch.full((max_length,), float("nan"), device=device)
        padded_gt_counts = torch.full((max_length,), float("nan"), device=device)
        padded_pred_counts[:len(pred_counts)] = pred_counts
        padded_gt_counts[:len(gt_counts)] = gt_counts

        gathered_pred_counts = [torch.zeros_like(padded_pred_counts) for _ in range(nprocs)]
        gathered_gt_counts = [torch.zeros_like(padded_gt_counts) for _ in range(nprocs)]
        dist.all_gather(gathered_pred_counts, padded_pred_counts)
        dist.all_gather(gathered_gt_counts, padded_gt_counts)

        pred_counts = torch.cat(gathered_pred_counts).cpu()
        gt_counts = torch.cat(gathered_gt_counts).cpu()
        pred_counts = pred_counts[~torch.isnan(pred_counts)].numpy()
        gt_counts = gt_counts[~torch.isnan(gt_counts)].numpy()
    else:
        pred_counts, gt_counts = np.array(pred_counts), np.array(gt_counts)

    # Release cached GPU memory accumulated during evaluation.
    torch.cuda.empty_cache()
    return calculate_errors(pred_counts, gt_counts)
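

# Usage sketch (hypothetical helper names): `build_model` and `build_val_loader`
# stand in for whatever the surrounding project provides; the metrics returned
# depend on `calculate_errors` in utils.
#
#     model = build_model()
#     val_loader = build_val_loader()
#     metrics = evaluate(model, val_loader, sliding_window=True, amp=True)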