# ZIP-B / losses / poisson_nll.py
# Yiming-M's picture
# 2025-07-31 15:53 🐣
# 0ecb9aa verified
from typing import Dict, Tuple

import torch
from torch import nn, Tensor

from .utils import _reshape_density
# Small offset added to the predicted rate inside the log; it guards against
# evaluating log(0) when a predicted value is exactly zero.
EPS = 1e-8


class PoissonNLL(nn.Module):
    """Poisson negative log-likelihood between a predicted rate map and a density map.

    For every sample the loss is ``sum_{h, w} (pred - gt * log(pred + EPS))``,
    i.e. the Poisson NLL without the ``log(gt!)`` term, which does not depend
    on the prediction.

    Args:
        reduction: how the per-sample losses are combined. ``"none"`` keeps one
            value per sample, ``"mean"`` averages them and ``"sum"`` adds them.
    """

    def __init__(
        self,
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        assert reduction in ["none", "mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
        self.reduction = reduction

    def forward(self, pred_den_map: Tensor, gt_den_map: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the Poisson loss of a predicted λ map against the ground truth.

        Args:
            pred_den_map: predicted λ map, shape (B, 1, H, W).
            gt_den_map: ground truth density map, shape (B, 1, H', W'). When its
                shape differs from ``pred_den_map``, H' and W' must be the same
                integer multiple of H and W; the map is then passed through
                ``_reshape_density`` with that multiple as ``block_size``.

        Returns:
            A tuple ``(loss, info)``:
              * ``loss`` - a scalar for ``"mean"`` / ``"sum"`` reduction, or a
                tensor of shape (B, 1) for ``"none"``.
              * ``info`` - ``{"pnll": loss.detach()}``, a graph-free copy of the
                same value (e.g. for logging).

        Raises:
            AssertionError: if either input is not 4-D, does not have exactly
                one channel, or the two shapes cannot be reconciled.
        """
        assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
        assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
        assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"

        if gt_den_map.shape != pred_den_map.shape:
            # Bring the ground truth to the prediction's resolution. Both spatial
            # axes must shrink by the same integer factor.
            gt_h, gt_w = gt_den_map.shape[-2], gt_den_map.shape[-1]
            pred_h, pred_w = pred_den_map.shape[-2], pred_den_map.shape[-1]
            assert gt_h % pred_h == 0 and gt_w % pred_w == 0 and gt_h // pred_h == gt_w // pred_w, f"Expected the spatial dimension of gt_den_map to be a multiple of that of pred_den_map, got {gt_den_map.shape} and {pred_den_map.shape}"
            gt_den_map = _reshape_density(gt_den_map, block_size=gt_h // pred_h)
            # Catches anything the resize cannot fix, e.g. a batch-size mismatch.
            assert gt_den_map.shape == pred_den_map.shape, f"Expected gt_den_map and pred_den_map to have the same shape, got {gt_den_map.shape} and {pred_den_map.shape}"

        gt_den_map = gt_den_map.to(pred_den_map.device)

        # Element-wise Poisson NLL (constant term dropped), summed over H and W.
        loss = (pred_den_map - gt_den_map * torch.log(pred_den_map + EPS)).sum(dim=(-1, -2))

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        # reduction == "none": keep the per-sample values untouched.

        return loss, {"pnll": loss.detach()}