import torch
from torch.amp import autocast
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn, Tensor
from torch.utils.data import DataLoader
from typing import Tuple, Optional
from tqdm import tqdm
import numpy as np

from utils import sliding_window_predict, barrier, calculate_errors
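

# Sketch (hypothetical helper, not part of the original module and not called by
# `evaluate`): the target size produced by the "upsample if smaller than the window"
# step inside `evaluate`, written as a pure function so the rule can be checked without
# a model or an image. It assumes exactly the logic used there: raise a side that is
# below `window_size` to `window_size`, rescale the other side by the original aspect
# ratio, and truncate with `int()`.
def _upsampled_size_sketch(height: int, width: int, window_size: int = 224) -> tuple:
    aspect_ratio = width / height
    # Height too small: set it to `window_size` and scale the width to match.
    if height < window_size:
        height, width = window_size, int(window_size * aspect_ratio)
    # Width (possibly after the first step) still too small: fix it the same way.
    if width < window_size:
        height, width = int(window_size / aspect_ratio), window_size
    return height, width


# Usage: a 100x300 (H x W) image becomes 224x672; a 500x400 image is left unchanged.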


def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    sliding_window: bool,
    max_input_size: int = 4096,
    window_size: int = 224,
    stride: int = 224,
    max_num_windows: int = 64,
    device: torch.device = torch.device("cuda"),
    amp: bool = False,
    local_rank: int = 0,
    nprocs: int = 1,
    progress_bar: bool = True,
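) -> Tuple[Tensor, Tensor]:
    """Evaluate a density-map counting model on `data_loader`.

    For every image, the predicted count is the sum of the predicted density map and the
    ground-truth count is the number of annotated points. Images whose height or width is
    below `window_size` are upsampled (bicubic, aspect ratio preserved) before inference.
    Sliding-window inference is used when `sliding_window` is True or when the image has
    more than `max_input_size ** 2` pixels; otherwise the whole image is passed to the model.

    With DDP (`nprocs > 1`), the per-rank counts are gathered from all processes before the
    errors are computed with `calculate_errors`. `local_rank` only controls which process
    shows the progress bar.
    """
    ddp = nprocs > 1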
    model = model.to(device)
    model.eval()
    pred_counts, gt_counts = [], []
    data_iter = tqdm(data_loader) if (local_rank == 0 and progress_bar) else data_loader

    for image, gt_points, _ in data_iter:
        image = image.to(device)
        image_height, image_width = image.shape[-2:]
        gt_counts.extend([len(p) for p in gt_points])

        # Upsample images whose height or width is below `window_size` so that both sides are
        # at least `window_size`, preserving the original aspect ratio.
        aspect_ratio = image_width / image_height
        if image_height < window_size:
            new_height = window_size
            new_width = int(new_height * aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width
        if image_width < window_size:
            new_width = window_size
            new_height = int(new_width / aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width

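        # Run inference without gradients; autocast uses the type of the device actually in use
        # instead of always assuming CUDA.
        with torch.no_grad(), autocast(device_type=torch.device(device).type, enabled=amp):
            if sliding_window or (image_height * image_width) > max_input_size ** 2:
                pred_den_maps = sliding_window_predict(model, image, window_size, stride, max_num_windows)
            else:
                pred_den_maps = model(image)

            # Each image's predicted count is the sum of its density map over the channel, height
            # and width dimensions. `.float()` guards against reduced-precision outputs (NumPy has
            # no bfloat16), and `.tolist()` yields Python floats directly.
            pred_counts.extend(pred_den_maps.sum(dim=(-1, -2, -3)).float().cpu().tolist())
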
    barrier(ddp)
    assert len(pred_counts) == len(gt_counts), f"Length of predictions and ground truths should be equal, but got {len(pred_counts)} and {len(gt_counts)}"

    if ddp:
        pred_counts = torch.tensor(pred_counts, dtype=torch.float32, device=device)
        gt_counts = torch.tensor(gt_counts, dtype=torch.float32, device=device)
        # Ranks may hold different numbers of samples, but `all_gather` needs equal-sized
        # tensors, so first find the largest local length across all processes.
        local_length = torch.tensor([len(pred_counts)], device=device)
        lengths = [torch.zeros_like(local_length) for _ in range(nprocs)]
        dist.all_gather(lengths, local_length)
        max_length = max(int(l.item()) for l in lengths)
        # Pad both tensors to `max_length` with NaN so the padding can be identified later.
        padded_pred_counts = torch.full((max_length,), float("nan"), device=device)
        padded_gt_counts = torch.full((max_length,), float("nan"), device=device)
        padded_pred_counts[:len(pred_counts)] = pred_counts
        padded_gt_counts[:len(gt_counts)] = gt_counts
        gathered_pred_counts = [torch.zeros_like(padded_pred_counts) for _ in range(nprocs)]
        gathered_gt_counts = [torch.zeros_like(padded_gt_counts) for _ in range(nprocs)]
        dist.all_gather(gathered_pred_counts, padded_pred_counts)
        dist.all_gather(gathered_gt_counts, padded_gt_counts)
        # Concatenate the results from all processes and drop the padding. One mask, taken from
        # the ground truths (which are never NaN), keeps predictions and ground truths aligned
        # even if the model itself produces a NaN count.
        pred_counts, gt_counts = torch.cat(gathered_pred_counts).cpu(), torch.cat(gathered_gt_counts).cpu()
        valid = ~torch.isnan(gt_counts)
        pred_counts, gt_counts = pred_counts[valid].numpy(), gt_counts[valid].numpy()

    else:
        pred_counts, gt_counts = np.array(pred_counts), np.array(gt_counts)

    torch.cuda.empty_cache()
    return calculate_errors(pred_counts, gt_counts)
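

# Sketch (hypothetical helper, not part of the original module): the NaN pad / gather /
# unpad scheme from the DDP branch of `evaluate`, simulated on a single process so it can
# be run without `torch.distributed`. Each inner list of `per_rank_counts` stands in for
# one rank's local counts, and `torch.cat` over the padded tensors plays the role of
# `dist.all_gather` followed by concatenation. Illustration only; it assumes real counts
# are never NaN, so every NaN left after gathering must be padding.
import torch


def _pad_gather_unpad_sketch(per_rank_counts):
    # Every "rank" pads to the longest local length, as `all_gather` requires.
    max_length = max(len(counts) for counts in per_rank_counts)
    padded = []
    for counts in per_rank_counts:
        p = torch.full((max_length,), float("nan"))
        p[:len(counts)] = torch.tensor(counts, dtype=torch.float32)
        padded.append(p)
    # Concatenate all ranks, then drop the NaN padding.
    gathered = torch.cat(padded)
    return gathered[~torch.isnan(gathered)]


# Usage: ranks holding 3 and 1 samples yield 4 counts in rank order, padding removed.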