openclimatefix / skillful_nowcasting

Implementation of DeepMind's Deep Generative Model of Radar (DGMR) https://arxiv.org/abs/2104.00954

Why use `grid_cell_regularizer` function over `GridCellLoss` class for grid cell regularization? #79

Open rutkovskii opened 17 hours ago

rutkovskii commented 17 hours ago

Hi,

I noticed that the implementation uses the grid_cell_regularizer function to compute the grid cell regularizer, but I have been exploring a GridCellLoss class instead. The class includes the weight function described in the paper, w(y) = max(y + 1, 24), and offers more flexibility for modifications. Is there a specific reason for choosing the function over a class-based implementation?

In the paper, its significance is worded as follows:

The grid cell regularizer ensures that the mean prediction remains close to the ground truth, and is averaged across all grid cells along the height H, width W, and lead-time N axes. It is weighted towards higher rainfall targets using the function w(y) = max(y + 1, 24), which operates element-wise for input vectors, and is clipped at 24 for robustness to spuriously large values in the radar.
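
For reference, my transcription of the paper's equation for this regularizer is:

    L_R(θ) = (1 / (H·W·N)) · || ( E_Z[ G_θ(Z) ] − X_{M+1:M+N} ) ⊙ w(X_{M+1:M+N}) ||_1

where ⊙ denotes element-wise multiplication and the expectation over the latent Z is estimated by Monte Carlo (6 generator calls).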

Observations on my data

Here's the weight function implementation:

import torch

def weight_fn(y, precip_weight_cap=24.0):
    """
    Weight function for the grid cell loss.
    w(y) = max(y + 1, precip_weight_cap), applied element-wise.

    Args:
        y: Tensor of rainfall intensities.
        precip_weight_cap: Custom cap for the weight function (24.0 in the paper).

    Returns:
        Weights for each grid cell.
    """
    return torch.max(y + 1, torch.tensor(precip_weight_cap, device=y.device))
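
A quick sanity check (illustrative values) confirms it operates element-wise and that every weight is at least 24 under this formula:

y = torch.tensor([0.5, 10.0, 50.0])
print(weight_fn(y))  # tensor([24., 24., 51.])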

My adjusted implementation of the GridCellLoss class:

from torch import nn


class GridCellLoss(nn.Module):
    """
    Grid Cell Regularizer loss from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
    """

    def __init__(self, weight_fn=None, precip_weight_cap=24.0):
        """
        Initialize GridCellLoss.

        Args:
            weight_fn: A function computing per-cell weights; the paper uses w(y) = max(y + 1, 24).
            precip_weight_cap: Custom cap value passed through to the weight function.
        """
        super().__init__()
        # Parenthesize the conditional: without the parentheses, weight_fn=None would
        # still produce a lambda (one that returns None), and forward() would crash.
        self.weight_fn = (lambda y: weight_fn(y, precip_weight_cap)) if weight_fn else None

    def forward(self, generated_images, targets):
        """
        Calculates the grid cell regularizer value. Assumes generated_images are the mean
        predictions from 6 calls to the generator (a Monte Carlo estimate of the
        expectation over the latent variable).

        Args:
            generated_images: Mean generated images from the generator.
            targets: Ground truth future frames.

        Returns:
            Grid cell regularizer term.
        """
        # The paper uses the L1 (absolute) difference; without abs(), positive and
        # negative errors cancel and the loss can go negative.
        difference = torch.abs(generated_images - targets)
        if self.weight_fn is not None:
            difference *= self.weight_fn(targets)
        # 1/(N*H*W); note that .mean() below also averages over every axis.
        difference /= targets.size(1) * targets.size(3) * targets.size(4)
        return difference.mean()
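
A minimal smoke test (the shapes are illustrative; I use [batch, lead_time, channels, height, width] to match the size(1)/size(3)/size(4) normalization above):

grid_loss = GridCellLoss(weight_fn=weight_fn, precip_weight_cap=24.0)
generated = torch.rand(2, 18, 1, 64, 64)      # stand-in for the mean of 6 generator calls
future_frames = torch.rand(2, 18, 1, 64, 64)  # stand-in for ground-truth future frames
print(grid_loss(generated, future_frames))    # 0-dim loss tensor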

Here's an example of how I integrated the class inside training_step:

        # Compute the grid cell regularizer using GridCellLoss
        gen_mean = torch.stack(predictions, dim=0).mean(dim=0)  # Mean over the 6 samples
        grid_cell_reg = self.grid_regularizer(gen_mean, future_images)
        # Original function call, for reference:
        # grid_cell_reg = grid_cell_regularizer(torch.stack(predictions, dim=0), future_images)
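
For completeness, the regularizer is constructed once in the model's __init__ (a sketch matching the attribute name used above):

        self.grid_regularizer = GridCellLoss(weight_fn=weight_fn, precip_weight_cap=24.0)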

For comparison, here’s the original grid_cell_regularizer function:

def grid_cell_regularizer(generated_samples, batch_targets):
    """Grid cell regularizer.

    Args:
      generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1].
      batch_targets: Tensor of size [batch_size, 18, 256, 256, 1].

    Returns:
      loss: A scalar loss tensor (torch.mean reduces over all axes).
    """
    gen_mean = torch.mean(generated_samples, dim=0)
    weights = torch.clip(batch_targets, 0.0, 24.0)
    loss = torch.mean(torch.abs(gen_mean - batch_targets) * weights)
    return loss
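
To make the practical difference concrete: the original function weights by clip(y, 0, 24), while the formula quoted from the paper is max(y + 1, 24), and the two diverge most on light rain (illustrative values):

y = torch.tensor([0.0, 1.0, 10.0, 30.0, 100.0])
print(torch.max(y + 1, torch.tensor(24.0)))  # max(y+1, 24): tensor([ 24.,  24.,  24.,  31., 101.])
print(torch.clip(y, 0.0, 24.0))              # clip(y, 0, 24): tensor([ 0.,  1., 10., 24., 24.])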