I noticed that the implementation uses the `grid_cell_regularizer` function to compute the grid cell regularizer, but I have been experimenting with a `GridCellLoss` class instead. The class applies the weight function described in the paper, w(y) = max(y + 1, 24), and is easier to modify. Is there a specific reason the implementation uses a plain function rather than a class-based loss with this weight function?

The paper describes its purpose as follows:
> The grid cell regularizer ensures that the mean prediction remains close to the ground truth, and is averaged across all grid cells along the height H, width W, and lead-time N axes. It is weighted towards higher rainfall targets using the function w(y) = max(y + 1, 24), which operates element-wise for input vectors, and is clipped at 24 for robustness to spuriously large values in the radar.

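
Written out as an equation, my reading of that description is roughly the following. This is my paraphrase of the paper's formula, so treat the notation as approximate: G_θ is the generator, Z the latent variable, X_{1:M} the context frames, X_{M+1:M+N} the N target frames, and ⊙ element-wise multiplication. The ℓ1 norm means the per-cell error enters as an absolute value before averaging.

$$
\mathcal{L}_{\text{grid}}(\theta) = \frac{1}{HWN}\left\lVert \left(\mathbb{E}_Z\left[G_\theta(Z; X_{1:M})\right] - X_{M+1:M+N}\right) \odot w\left(X_{M+1:M+N}\right)\right\rVert_1
$$
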
Observations on my data
- Using the `grid_cell_regularizer` function, `train_grid_loss` was in the 150s.
- Using a `GridCellLoss` class with the paper's weight function, w(y) = max(y + 1, 24), `train_grid_loss` was tiny and even negative, e.g. -4.7147e-05.

Here's the weight function implementation:
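```python
import torch


def weight_fn(y, precip_weight_cap=24.0):
    """
    Weight function for the grid cell loss.

    w(y) = max(y + 1, precip_weight_cap)

    Note: torch.max(y + 1, cap) makes precip_weight_cap a lower bound on the
    weight (every weight is >= cap), not an upper bound.

    Args:
        y: Tensor of rainfall intensities.
        precip_weight_cap: Second argument of the max (24 in the paper).

    Returns:
        Weights for each grid cell.
    """
    cap = torch.tensor(precip_weight_cap, dtype=y.dtype, device=y.device)
    return torch.max(y + 1, cap)
```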
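
To see how the literal formula behaves, here is a small sketch comparing it with two alternatives: `min(y + 1, 24)`, which is my assumption of what "clipped at 24" is meant to express, and the `clip(y, 0, 24)` weighting that `grid_cell_regularizer` uses. The helper names `w_max`, `w_min` and `w_clip` are hypothetical.

```python
import torch


def w_max(y, cap=24.0):
    # Literal reading of the paper's formula: every weight is at least `cap`.
    return torch.clamp(y + 1, min=cap)


def w_min(y, cap=24.0):
    # Assumed "clipped at 24" reading: grows with y but never exceeds `cap`.
    return torch.clamp(y + 1, max=cap)


def w_clip(y, cap=24.0):
    # The weighting used by grid_cell_regularizer: clip(y, 0, cap), no +1.
    return torch.clip(y, 0.0, cap)


rain = torch.tensor([0.0, 1.0, 10.0, 30.0])  # example intensities
for name, fn in [("max(y + 1, 24)", w_max), ("min(y + 1, 24)", w_min), ("clip(y, 0, 24)", w_clip)]:
    print(name, fn(rain).tolist())
# max(y + 1, 24) [24.0, 24.0, 24.0, 31.0]
# min(y + 1, 24) [1.0, 2.0, 11.0, 24.0]
# clip(y, 0, 24) [0.0, 1.0, 10.0, 24.0]
```

With the literal `max`, dry and lightly raining cells all get the same weight of 24, and only intensities above 23 are weighted more, so large radar values are not capped at all.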
My adjusted implementation of the GridCellLoss class:
```python
import torch
import torch.nn as nn


class GridCellLoss(nn.Module):
    """
    Grid Cell Regularizer loss from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
    """

    def __init__(self, weight_fn=None, precip_weight_cap=24.0):
        """
        Initialize GridCellLoss.

        Args:
            weight_fn: A function w(y, precip_weight_cap) returning per-cell weights
                (in the paper, w(y) = max(y + 1, 24)). If None, no weighting is applied.
            precip_weight_cap: Value passed as the second argument to weight_fn.
        """
        super().__init__()
        # Only wrap weight_fn when one is given; otherwise store None so the
        # `is not None` check in forward() really skips the weighting.
        if weight_fn is not None:
            self.weight_fn = lambda y: weight_fn(y, precip_weight_cap)
        else:
            self.weight_fn = None
        print("Acquired weight_fn: ", self.weight_fn)

    def forward(self, generated_images, targets):
        """
        Calculates the grid cell regularizer value. Assumes generated_images are the mean
        predictions from 6 calls to the generator (Monte Carlo estimate of the expectation
        over the latent variable).

        Args:
            generated_images: Mean generated images from the generator
            targets: Ground truth future frames

        Returns:
            Grid Cell Regularizer term
        """
        # Signed difference: no absolute value is taken, so positive and
        # negative errors can cancel in the mean below.
        difference = generated_images - targets
        if self.weight_fn is not None:
            difference = difference * self.weight_fn(targets)
        # 1/HWN, assuming a [B, N, C, H, W] layout. Note that .mean() below
        # also divides by the total number of elements.
        difference = difference / (targets.size(1) * targets.size(3) * targets.size(4))
        return difference.mean()
```
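
Compared with `grid_cell_regularizer`, which takes `torch.abs(...)` of the difference, this `forward` averages a signed difference, and it divides by HWN on top of `.mean()`. A quick sketch on random tensors illustrates how each of these affects the magnitude; the shapes and the noise level are arbitrary assumptions, not my real data.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes following the [B, N, C, H, W] layout assumed by GridCellLoss.
B, N, C, H, W = 2, 18, 1, 64, 64
targets = torch.rand(B, N, C, H, W) * 5.0
gen_mean = targets + 0.5 * torch.randn_like(targets)  # unbiased noisy "prediction"

diff = gen_mean - targets
signed_mean = diff.mean()            # positive and negative errors cancel
abs_mean = diff.abs().mean()         # absolute error, as in grid_cell_regularizer
extra_norm = abs_mean / (H * W * N)  # .mean() followed by another 1/HWN

print(f"signed mean:         {signed_mean.item():.3e}")
print(f"absolute mean:       {abs_mean.item():.3e}")
print(f"absolute mean / HWN: {extra_norm.item():.3e}")
```

The signed mean lands near zero and can take either sign, while the extra 1/HWN factor shrinks the value by about 7.4e4 here (and by about 1.2e6 for 18 frames of 256 x 256).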
Here's how I integrated the class into `training_step`:
```python
# Compute Grid Cell Loss using GridCellLoss
print("Computing Grid Cell Loss")
gen_mean = torch.stack(predictions, dim=0).mean(dim=0)  # Mean over samples
grid_cell_reg = self.grid_regularizer(gen_mean, future_images)

# Original function, commented out:
# grid_cell_reg = grid_cell_regularizer(torch.stack(predictions, dim=0), future_images)
```
For comparison, here's the original `grid_cell_regularizer` function:
```python
import torch


def grid_cell_regularizer(generated_samples, batch_targets):
    """Grid cell regularizer.

    Args:
        generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1].
        batch_targets: Tensor of size [batch_size, 18, 256, 256, 1].

    Returns:
        loss: A scalar tensor (torch.mean over all elements).
    """
    gen_mean = torch.mean(generated_samples, dim=0)
    weights = torch.clip(batch_targets, 0.0, 24.0)
    loss = torch.mean(torch.abs(gen_mean - batch_targets) * weights)
    return loss
```
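
On the function-vs-class question itself, the two styles can compute exactly the same value. Below is a minimal sketch of a hypothetical `ClippedGridCellLoss` module that reproduces `grid_cell_regularizer` while keeping the weight function pluggable; the class name, the `clip_weights` helper and the random test shapes are my own assumptions.

```python
import torch
import torch.nn as nn


def clip_weights(y, cap=24.0):
    # Same weighting as grid_cell_regularizer: clip(y, 0, cap).
    return torch.clip(y, 0.0, cap)


class ClippedGridCellLoss(nn.Module):
    """Hypothetical class-based equivalent of grid_cell_regularizer.

    Takes stacked samples [n_samples, ...] like the original function,
    but lets the weight function and cap be swapped out.
    """

    def __init__(self, weight_fn=clip_weights, precip_weight_cap=24.0):
        super().__init__()
        self.weight_fn = weight_fn
        self.precip_weight_cap = precip_weight_cap

    def forward(self, generated_samples, batch_targets):
        gen_mean = generated_samples.mean(dim=0)  # Monte Carlo mean over samples
        weights = self.weight_fn(batch_targets, self.precip_weight_cap)
        # Weighted absolute error, averaged once over every element.
        return (torch.abs(gen_mean - batch_targets) * weights).mean()


def grid_cell_regularizer(generated_samples, batch_targets):
    # Copy of the original function, for comparison.
    gen_mean = torch.mean(generated_samples, dim=0)
    weights = torch.clip(batch_targets, 0.0, 24.0)
    return torch.mean(torch.abs(gen_mean - batch_targets) * weights)


torch.manual_seed(0)
samples = torch.rand(6, 2, 18, 1, 16, 16) * 30.0  # [n_samples, B, N, C, H, W]
targets = torch.rand(2, 18, 1, 16, 16) * 30.0

loss_fn = ClippedGridCellLoss()
print(loss_fn(samples, targets).item(), grid_cell_regularizer(samples, targets).item())
```

Passing the `weight_fn` defined earlier (the `max(y + 1, 24)` version) as `weight_fn=` would switch to the paper's literal weighting while keeping the absolute error and a single averaging step.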