| | import torch |
| | from torch import nn |
| | import numpy as np |
| | import os, json |
| | from tqdm import tqdm |
| | from argparse import ArgumentParser |
| | from typing import Dict |
| |
|
| | import datasets |
| |
|
| |
|
class SumPool2d(nn.Module):
    """Sliding-window *sum* over the last two dimensions of the input.

    Built on ``nn.AvgPool2d`` with ``divisor_override=1``, so every output
    element is the plain sum of its ``kernel_size x kernel_size`` window
    instead of the mean.

    Args:
        kernel_size: Side length of the square window.
        stride: Step between consecutive windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        # Overriding the divisor to 1 turns average pooling into sum pooling.
        self.sum_pool = nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=stride,
            divisor_override=1,
        )

    def forward(self, x):
        """Return the window sums of ``x`` (no padding, so only full windows)."""
        pooled = self.sum_pool(x)
        return pooled
| |
|
| |
|
| | def _update_dict(d: Dict, keys: np.ndarray, values: np.ndarray) -> Dict: |
| | keys = keys.tolist() if isinstance(keys, np.ndarray) else keys |
| | values = values.tolist() if isinstance(values, np.ndarray) else values |
| | for k, v in zip(keys, values): |
| | d[k] = d.get(k, 0) + v |
| |
|
| | return d |
| |
|
| |
|
# Side lengths (in pixels) of the square windows over which local counts are taken.
_KERNEL_SIZES = (4, 7, 8, 14, 16, 28, 32, 56, 64)


def _spatial_2d(window: torch.Tensor) -> torch.Tensor:
    """Return ``window`` reshaped to its trailing two (spatial) dimensions.

    A bare ``torch.squeeze`` also removes a spatial axis of size 1, which
    happens when a window is exactly as tall or wide as the image. That
    breaks the two-way unpack of ``torch.where``. Reshaping to the last
    two dims keeps (H, W) intact.
    Assumes every leading dimension has size 1. TODO: confirm that the
    density maps from ``datasets.Crowd`` are single-channel, unbatched.
    """
    return window.reshape(window.shape[-2:])


def _update_max(record: Dict, window: torch.Tensor, img_name) -> None:
    """Update ``record`` in place if ``window`` contains a new largest count.

    ``record["x"]`` and ``record["y"]`` receive the row and column index
    (in that order, as returned by ``torch.where``) of the first maximal
    window in row-major order.
    """
    peak = window.max()
    peak_value = peak.item()
    if peak_value > record["max"]:
        rows, cols = torch.where(window == peak)
        record["max"] = peak_value
        record["name"] = img_name
        record["x"] = rows[0].item()
        record["y"] = cols[0].item()


def _accumulate_histogram(hist: Dict, values: torch.Tensor) -> Dict:
    """Add the frequency of every distinct value in ``values`` into ``hist``."""
    flat = values.reshape(-1).cpu().numpy().astype(int)
    uniques, freqs = np.unique(flat, return_counts=True)
    return _update_dict(hist, uniques, freqs)


def _get_counts(
    dataset_name: str,
    device: torch.device,
) -> None:
    """Histogram local crowd counts over a dataset's training split.

    For every training image, the density map is summed over each
    ``k x k`` window (stride 1) for every ``k`` in ``_KERNEL_SIZES``. The
    rounded sums are histogrammed per window size, next to the rounded
    per-pixel values (key ``1``). For each window size, the largest sum
    seen is recorded together with the image name and its location.

    Output files:
        ``./counts/<dataset_name>.json``: ``{window_size: {count: frequency}}``.
        ``./counts/<dataset_name>_max.json``: ``{window_size: {"max", "name", "x", "y"}}``.
    ``json.dump`` stores the integer keys as strings.

    Args:
        dataset_name: Dataset name forwarded to ``datasets.Crowd``.
        device: Device on which the window sums are computed.
    """
    filters = {size: SumPool2d(size, 1).to(device) for size in _KERNEL_SIZES}
    # Key 1 holds the per-pixel histogram; the other keys hold window-sum histograms.
    counts = {1: {}}
    counts.update({size: {} for size in _KERNEL_SIZES})
    max_counts = {
        size: {"max": 0., "name": None, "x": None, "y": None}
        for size in _KERNEL_SIZES
    }

    counts_dir = os.path.join(os.getcwd(), "counts")
    os.makedirs(counts_dir, exist_ok=True)

    dataset = datasets.Crowd(dataset=dataset_name, split="train", transforms=None, return_filename=True)
    print(f"Counting {dataset_name} dataset")

    # This is a pure counting pass, so no autograd graph is needed.
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            _, _, density, img_name = dataset[i]
            density = density.to(device)

            # Round per-pixel values the same way as the window sums below.
            # A plain ``astype(int)`` truncates, so 0.9999 would become 0.
            _accumulate_histogram(counts[1], torch.round(density).int())

            for size, sum_pool in filters.items():
                window = _spatial_2d(torch.round(sum_pool(density)).int())
                _update_max(max_counts[size], window, img_name)
                _accumulate_histogram(counts[size], window)

    with open(os.path.join(counts_dir, f"{dataset_name}.json"), "w") as f:
        json.dump(counts, f)

    with open(os.path.join(counts_dir, f"{dataset_name}_max.json"), "w") as f:
        json.dump(max_counts, f)
| |
|
| |
|
def parse_args():
    """Parse the command-line options of the local-count script.

    Returns:
        argparse.Namespace with ``dataset`` (required; one of the four
        supported dataset keys) and ``device`` (a device string,
        ``"cuda"`` unless overridden).
    """
    supported = ["nwpu", "ucf_qnrf", "shanghaitech_a", "shanghaitech_b"]
    parser = ArgumentParser(description="Get local counts of the dataset")
    parser.add_argument("--dataset", type=str, choices=supported, required=True,
                        help="The dataset to use.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="The device to use.")
    return parser.parse_args()
| |
|
| |
|
if __name__ == "__main__":
    # Read CLI flags, normalise the dataset name and device, then run the counting pass.
    args = parse_args()
    dataset_name = datasets.standardize_dataset_name(args.dataset)
    device = torch.device(args.device)
    _get_counts(dataset_name, device)
| |
|