mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Feature Request: Add Functionality to Apply Constraints to Predictions #19

Open sadamov opened 1 month ago

sadamov commented 1 month ago

I am proposing the addition of a new method to our model class, designed to apply constraints to predictions to ensure that the values fall within specified bounds. This functionality would be useful for maintaining the integrity of our model's predictions in scenarios where certain variables have predefined limits.

Proposed Method:

def apply_constraints(self, prediction):
    """
    Apply constraints to prediction to ensure values are within the
    specified bounds
    """
    for param, (min_val, max_val) in constants.PARAM_CONSTRAINTS.items():
        indices = self.variable_indices[param]
        # Clamp all channels of this variable at once; torch.clamp
        # treats a None bound as unbounded on that side
        prediction[:, :, :, indices] = torch.clamp(
            prediction[:, :, :, indices],
            min=min_val,
            max=max_val,
        )
    return prediction

Rationale:

- Data Integrity: Ensuring that predictions adhere to real-world constraints is essential for the reliability of our model's outputs. This method would allow us to enforce these constraints directly on the prediction tensor.
- Flexibility: By dynamically applying constraints based on the variable indices, we can maintain a high degree of flexibility in how we handle different variables with varying constraints.

The method could be added to the ARModel class, which is our primary model class. The constants.PARAM_CONSTRAINTS dictionary, which maps variable names to their minimum and maximum values, should be used to determine the constraints for each variable.

PARAM_CONSTRAINTS = {
    "RELHUM": (0, 100),
    "CLCT": (0, 100),
    "TOT_PREC": (0, None),
}
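A small standalone illustration (not part of the model class) of how these bounds map onto torch.clamp, which treats a None bound as unbounded on that side, so entries like TOT_PREC with no upper limit can be passed straight through:

```python
import torch

x = torch.tensor([-5.0, 50.0, 150.0])

# Both bounds given, e.g. RELHUM with (0, 100)
print(torch.clamp(x, min=0, max=100))   # tensor([  0.,  50., 100.])

# Lower bound only, e.g. TOT_PREC with (0, None)
print(torch.clamp(x, min=0, max=None))  # tensor([  0.,  50., 150.])
```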

This feature is closely related to #18

joeloskarsson commented 1 month ago

I would want to do this by constraining the model output itself, also for what is used during training (a clamp gives zero gradient outside the bounds). Specifically, variables with constraints should be handled by rescaling a sigmoid output, or by a softplus (for > 0). As this applies to all models, it would be nice to include it at a high level, so it does not have to be handled separately in each model.

One option would be to have a function constrain_outputs in the ARModel class that applies these activation functions. Then everything that inherits from this can just apply this to the state after making a prediction.
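A minimal sketch of this idea (hypothetical, not from the codebase: the standalone `constrain_outputs` signature and the `variable_indices` mapping are assumed to mirror the proposal above). Bounded variables get a rescaled sigmoid and lower-bounded ones a shifted softplus, so the constraint holds by construction and gradients flow during training:

```python
import torch
import torch.nn.functional as F

# Assumed constraint spec, mirroring constants.PARAM_CONSTRAINTS:
# (min, max) -> rescaled sigmoid, (min, None) -> shifted softplus.
PARAM_CONSTRAINTS = {
    "RELHUM": (0, 100),
    "CLCT": (0, 100),
    "TOT_PREC": (0, None),
}


def constrain_outputs(prediction, variable_indices):
    """Map raw network outputs onto their valid ranges (differentiable)."""
    out = prediction.clone()  # avoid in-place ops on the autograd graph
    for param, (min_val, max_val) in PARAM_CONSTRAINTS.items():
        idx = variable_indices[param]
        raw = prediction[..., idx]
        if max_val is not None:
            # Rescaled sigmoid maps R smoothly onto (min_val, max_val)
            out[..., idx] = min_val + (max_val - min_val) * torch.sigmoid(raw)
        else:
            # Shifted softplus maps R smoothly onto (min_val, inf)
            out[..., idx] = min_val + F.softplus(raw)
    return out
```

In ARModel this could be called once on the state after each prediction step, so every subclass inherits the constrained behaviour without handling it separately.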