aditya-grover / climate-learn

Model Refactor #88

Closed jasonjewik closed 1 year ago

jasonjewik commented 1 year ago

Model Refactor

Summary

The models in ClimateLearn are both cumbersome to use (no snappy way to load presets, climatology has to be manually set, baselines are baked into models) and insufficiently flexible (hard to add new models). In this issue, I outline proposed changes to resolve these issues. Ideally, I want the model quickstart to change from this:

dm = DataModule(...)
model_kwargs = {...}
optim_kwargs = {...}
mm = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs
)
set_climatology(mm, dm)
fit_lin_reg_baseline(mm, dm, reg_hparam=0.0)

To this:

dm = DataModule(...)
mm = load_forecasting_module(dm, preset="rasp-thuerey-2020")
trainer = Trainer(...)
trainer.fit(mm, dm)
trainer.test(mm, dm)

Model Loading

Examples

Here, I show examples for load_forecasting_module; everything is analogous for load_downscaling_module. I split the original load_model function by task to mirror the data module design, where ForecastingArgs is a distinct class from DownscalingArgs.

  1. load_forecasting_module(data_module, preset)

    Loads a preset model module. For example, the following loads the model and optimizer described in Rasp and Thuerey (2020).

    load_forecasting_module(dm, preset="rasp-thuerey-2020")

    Presets also exist for baselines.

    load_forecasting_module(dm, preset="climatology")
    load_forecasting_module(dm, preset="persistence")
    load_forecasting_module(dm, preset="linear-regression")
  2. load_forecasting_module(data_module, preset, model_kwargs)

    Loads a preset model module. The user can also pass keyword arguments to modify the model architecture. For example, the following loads Rasp and Thuerey's model, but changes the dropout.

    load_forecasting_module(
        dm,
        preset="rasp-theurey-2020",
        model_kwargs={"dropout": 0.3}
    )
  3. load_forecasting_module(data_module, preset, model_kwargs, optim, optim_kwargs)

    Loads a preset model module. The user can also pass keyword arguments to modify the model architecture. They can also specify the name of an optimizer (which is built into ClimateLearn) and keyword arguments for the optimizer. For example,

    load_forecasting_module(
        dm,
        preset="rasp-theurey-2020",
        model_kwargs={"dropout": 0.3},
        optim="adamw",
        optim_kwargs={"betas": (0.9, 0.95)}
    )
  4. load_forecasting_module(data_module, preset, model_kwargs, optimizer)

    Loads a preset model module. The user can also pass keyword arguments to modify the model architecture. They can also specify an already instantiated optimizer. For example,

    load_forecasting_module(
        dm,
        preset="rasp-theurey-2020",
        model_kwargs={"dropout": 0.3},
        optimizer=my_cool_optimizer
    )
  5. load_forecasting_module(data_module, model, model_kwargs, optim, optim_kwargs)

    Loads a model module with the given model and optimizer, which are defined in ClimateLearn but can be customized by model_kwargs and optim_kwargs. For example,

    load_forecasting_module(
        dm,
        model="resnet",
        model_kwargs={"n_blocks": 2},
        optim="adamw",
        optim_kwargs={"betas": (0.9, 0.95)}
    )
  6. load_forecasting_module(data_module, model, model_kwargs, optimizer)

    Loads a model module with the given model, which is defined in ClimateLearn but can be customized by model_kwargs. The optimizer is specified separately. For example:

    load_forecasting_module(
        data_module,
        model="resnet",
        model_kwargs={"n_blocks": 2},
        optimizer=my_cool_optimizer
    )
  7. load_forecasting_module(data_module, net, optimizer)

    Loads a model module which wraps the user-specified network and optimizer. For example:

    load_forecasting_module(
        data_module,
        net=my_cool_network,
        optimizer=my_cool_optimizer,
    )

Function Signature

load_xxx_module(
    data_module: pl.LightningDataModule,
    preset: Optional[str] = None,
    model: Optional[str] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
    optim: Optional[str] = None,
    optim_kwargs: Optional[Dict[str, Any]] = None,
    net: Optional[torch.nn.Module] = None,
    optimizer: Optional[Union[torch.optim.Optimizer, Dict[str, Any]]] = None,
    train_loss: Optional[Union[Callable, List[Callable]]] = None,
    val_loss: Optional[Union[Callable, List[Callable]]] = None,
    test_loss: Optional[Union[Callable, List[Callable]]] = None
)

Note that preset and model are aliases for each other. They are kept as two distinct arguments for the sake of clarity. For example, the following two function calls return the same module:
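
load_forecasting_module(dm, preset="rasp-thuerey-2020")
load_forecasting_module(dm, model="rasp-thuerey-2020")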

But in the first case, it is more obvious that the user wants the model defined in Rasp and Thuerey (2020). If both preset and model are specified, a RuntimeError will be thrown. The same behavior applies when net is passed while model is specified, or for any other conflicting combination of arguments.
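
Roughly, I imagine the conflict check inside load_xxx_module looking like this (the error messages are just illustrative):

def load_forecasting_module(dm, preset=None, model=None, net=None, **kwargs):
    # `preset` and `model` are aliases, so passing both is ambiguous
    if preset is not None and model is not None:
        raise RuntimeError("Specify either `preset` or `model`, not both")
    # a user-instantiated network conflicts with any string-specified model
    if net is not None and (preset is not None or model is not None):
        raise RuntimeError("`net` cannot be combined with `preset` or `model`")
    # ...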

The optimizer argument can be either a PyTorch optimizer or a dictionary containing two keys: "optimizer" and "lr_scheduler". If it is just a PyTorch optimizer, no scheduler is used for the optimization.
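
For example, attaching a learning rate scheduler to a custom network could look like the following (my_cool_network is a placeholder, as in the examples above):

optimizer = torch.optim.AdamW(my_cool_network.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
load_forecasting_module(
    dm,
    net=my_cool_network,
    optimizer={"optimizer": optimizer, "lr_scheduler": scheduler}
)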

I also add arguments for specifying loss functions. If these are left as None, the default loss functions specified in ClimateLearn will be used, but some users might want the flexibility to override them. For example, someone might be interested in using the AtmoDist loss for downscaling.
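
Assuming the loss functions are callables of the form loss(prediction, target), overriding the training loss might look like this (my_custom_loss is a placeholder):

def my_custom_loss(pred, target):
    # placeholder loss: mean absolute error
    return (pred - target).abs().mean()

load_forecasting_module(
    dm,
    model="resnet",
    model_kwargs={"n_blocks": 2},
    train_loss=my_custom_loss
)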

How does this solve existing problems?

  1. The user can easily load presets. I've shown this for Rasp and Thuerey, but we could also include ClimaX, Weyn et al. (2020), and others. Besides loading the architectures, we can also load pre-trained models when possible. For example, we could have both "climax", which loads the untrained ClimaX model, and "climax-pretrained", which loads the pre-trained ClimaX model.

  2. Climatology is set automatically. ClimateLearn requires climatology to be set before training. It doesn't make sense to require the user to remember to do this. Here, climatology is set in the load_xxx_module function. I show how this is done below.

  3. Baselines are not baked into models. As pointed out in Issue 83, it doesn't always make sense to run persistence because the data module might not support it. Furthermore, the user might not care to see these baselines. In my proposed changes, we separate out the baselines into their own models. If the user wishes to run climatology, persistence, or linear regression, they can do that the same as any other model. For example,

    load_forecasting_module(dm, preset="climatology")
  4. New models are easier to add. The user can modify ClimateLearn's presets (e.g., Rasp and Thuerey, ClimaX) and built-in architectures (e.g., ResNet, ViT), and they can define their own network and/or optimizer and pass these to the load_xxx_module function (see the sketch after this list). We can include a page in the documentation about what API is expected for forecasting networks versus downscaling networks.
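
To illustrate the last point, here is a minimal sketch of a user-defined network and optimizer passed to load_forecasting_module. I assume forecasting networks map (batch, channels, lat, lon) tensors to tensors of the same spatial shape; the documentation page mentioned above would pin down the exact API.

import torch

class MyCoolNetwork(torch.nn.Module):
    # assumed interface: (batch, in_channels, lat, lon) -> (batch, out_channels, lat, lon)
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

my_cool_network = MyCoolNetwork(in_channels=3, out_channels=3)
load_forecasting_module(
    dm,
    net=my_cool_network,
    optimizer=torch.optim.AdamW(my_cool_network.parameters(), lr=1e-4)
)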

Setting Climatology

In the load_xxx_module function, we can do the following to set climatology automatically.

def load_forecasting_module(dm, ...):
    # ...
    mm = ForecastingLitModule(...)
    mm.set_climatology(dm.get_climatology("all"))
    # ...

This relies upon pull request 81 being merged, and also a minor change to DataModule.get_climatology.

Baselines

For the persistence baseline, we can do the following to determine if it is available.

def load_forecasting_module(dm, ...):
    # ...
    if preset == "persistence":
        if set(dm.out_vars).issubset(dm.in_vars):
            mm = ForecastingLitModule(...)
        else:
            raise RuntimeError("Persistence requires the output variables to be a subset of the input variables")
    # ...

Again, this would require just a minor change so that the input variables and the output variables of the dataset are both available at the DataModule level.
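
One straightforward way to do this, assuming the underlying dataset already tracks its variables (train_dataset is a hypothetical attribute name):

class DataModule(pl.LightningDataModule):
    # ...
    @property
    def in_vars(self):
        return self.train_dataset.in_vars

    @property
    def out_vars(self):
        return self.train_dataset.out_vars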

Conclusion

In making these changes, I aim for two goals: first, to make it easier to run benchmark models; second, to make it easier to add a custom model. The flexibility of my proposed API allows for a balance between these two goals.

prakhar6sharma commented 1 year ago

Overall, the proposal looks great. I strongly agree with the issues raised regarding the baselines.

My only concern: can we remove the model module's dependency on the data module for instantiation and let the trainer act as a mediator between the two modules? In the longer run, for efficiency, the data module will contain data only after its .setup() is called. This can cause different processes to hold different copies of the data, so the model may use the wrong copy. If we let the Trainer handle this, it can aggregate the information across processes and then pass it on to the model module.

jasonjewik commented 1 year ago

Yes, I also want to remove the dependency between the model and data modules. The problem is that the model module needs to have climatology set because some of the metrics require it. The only way I can see around this issue is to override Trainer.test, but that feels like a separate, future PR to me.

Also, are my requested modifications to the data module possible?

prakhar6sharma commented 1 year ago

Also, are my requested modifications to the data module possible?

Yes, it should be straightforward.

jasonjewik commented 1 year ago

My only concern: can we remove the model module's dependency on the data module for instantiation and let the trainer act as a mediator between the two modules?

@prakhar6sharma just want to follow up on this. Do you think it is acceptable that, for the initial model refactor, I proceed with the current plan to use the data module? I believe we can revisit this in a future PR.

prakhar6sharma commented 1 year ago

Feel free to go ahead with this plan. As you said, we can revisit this in a future PR.