microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Add new pretrained weights #1043

Open calebrob6 opened 1 year ago

calebrob6 commented 1 year ago

Summary

This issue tracks progress on adding new pretrained weights from the related literature to TorchGeo:

and many, many more:

Rationale

Foundation models (FMs) are among the most substantial developments in recent ML research. FMs pre-trained on ImageNet are one of the core components that make torchvision and transformers so popular. TorchGeo serves as a collection of Earth observation (EO) FMs, allowing researchers to quickly and easily experiment with and design new FMs. This is critical for researchers who want to apply FMs via transfer learning to downstream tasks with small labeled datasets.

Implementation

See #2057, #1903, #1884, etc. for recent PRs adding new FMs.

If you would like to volunteer to add a particular FM, please comment on this issue to say that you're working on this.

Not sure where to get started? FMs that can be considered "multi-modal" (the same set of pre-trained weights can dynamically handle imagery from many different sensors) are the highest priority!

nilsleh commented 1 year ago
  • Which of the 100k or 1M models were implemented?

The 1M models were implemented.

  • This is partially done (@nilsleh do you mind commenting which you didn't import and why?)

From that list, we are still missing MAE and Date2Vec. When I first tried to extract the state dicts for those, I ran into a strange issue and then forgot about it, but I will look into it again!

nilsleh commented 1 year ago

Regarding these weights, there are three things I am running into:

  1. Their checkpoint contains the model weights, the optimizer state, and more, so at some point you would have to do state_dict = state_dict["model"], whereas all of our other pretrained weights are already bare model checkpoints.
  2. Their checkpoint file was saved from a GPU, and I don't know where to add the map_location=torch.device("cpu") argument other than in the torchvision source code. We call get_state_dict(), which does not accept a map_location argument, although load_state_dict_from_url() does. So I think we need to open an issue with torchvision?
  3. Several of their models are ViTAE models, and I cannot find those in timm.
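For issues 1 and 2, a workaround outside torchvision might look like the minimal sketch below. It assumes the checkpoint is a dict that stores the weights under a "model" key next to the optimizer state; the function name load_model_state_dict is hypothetical and not part of TorchGeo or torchvision.

```python
import torch


def load_model_state_dict(path: str) -> dict:
    """Hypothetical helper: load a training checkpoint, return only model weights.

    Assumes the checkpoint is a dict with the weights under a "model" key,
    alongside optimizer state and other training metadata.
    """
    # map_location="cpu" remaps tensors that were saved from a GPU onto the
    # CPU, so the checkpoint also loads on machines without CUDA.
    checkpoint = torch.load(path, map_location="cpu")
    # Unwrap the nested model weights; if the file is already a bare
    # state dict, return it unchanged.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        return checkpoint["model"]
    return checkpoint
```

When downloading instead of reading a local file, torch.hub.load_state_dict_from_url(url, map_location="cpu") accepts the same map_location argument.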

It feels like the first two issues could recur with other pretrained weights as well, if their license does not allow us to upload them to Hugging Face in the format we prefer.

adamjstewart commented 1 year ago

For 1 and 2, my preference would be to modify the checkpoint and save only the backbone on the CPU so that users or Lightning can map it to the GPU themselves. But of course that requires a favorable license. I would start by inquiring about licenses. If there's an issue with the license, we can suggest changing it and describe the use case we have in mind. If they don't respond or won't change it, then and only then should we consider more complex code in TorchGeo or changes to torchvision.
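Re-saving only the backbone on the CPU could be sketched roughly as follows. The "backbone." key prefix and the function name resave_backbone are hypothetical; real checkpoints may use a different prefix or none at all, so the keys should be inspected first.

```python
import torch


def resave_backbone(state_dict: dict, out_path: str, prefix: str = "backbone.") -> dict:
    """Hypothetical helper: keep only backbone weights, move them to CPU, save.

    The default "backbone." prefix is an assumed naming convention.
    """
    backbone = {
        # Strip the prefix so keys match a standalone backbone model.
        key[len(prefix):]: tensor.detach().cpu()
        for key, tensor in state_dict.items()
        if key.startswith(prefix)
    }
    # All tensors now live on the CPU; users or Lightning can move the
    # model to another device after loading.
    torch.save(backbone, out_path)
    return backbone
```

The re-saved file is then a bare state dict that can be loaded with model.load_state_dict(torch.load(out_path)) on a matching backbone architecture, without any unwrapping or map_location handling.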

calebrob6 commented 1 year ago

Yep, I think re-saving the weights is the way to go.

adamjstewart commented 8 months ago

The list to end all lists: https://github.com/Jack-bo1220/Awesome-Remote-Sensing-Foundation-Models

wangyi111 commented 4 months ago

I will be working on adding larger ViTs from SSL4EO-S12, FG-MAE, DeCUR, and SoftCon over the coming days.