bair-climate-initiative / scale-mae

Make your models invariant to changes in scale.
https://ai-climate.berkeley.edu/scale-mae-website/

Running the model #1

Closed calebrob6 closed 1 year ago

calebrob6 commented 1 year ago

Hey all, I'm trying to run the model on new imagery with the following code, where checkpoint = scalemae-vitlarge-800.pth, gsd_ratio is set to the spatial resolution of the dataset, and preprocessing is channel-wise standardization using dataset statistics (instead of, e.g., ImageNet statistics). Using this and the EuroSAT implementation in torchgeo, I get 0.954 test accuracy with a KNN-5 at gsd_ratio=1 (close to the 0.960 reported in your paper), but performance degrades to 0.876 with gsd_ratio=10. Any ideas where I'm going wrong?

import torch
import torch.nn as nn

from src.models_vit import vit_large_patch16
from src.pos_embed import interpolate_pos_embed

class ScaleMAEWrapper(nn.Module):
    def __init__(self, checkpoint, gsd_ratio):
        super().__init__()
        self.model = vit_large_patch16(
            img_size=224,
            num_classes=0,
            global_pool="avg"
        )
        # Load the pretrained weights; drop the classification head, whose
        # shape does not match a num_classes=0 model
        checkpoint = torch.load(checkpoint, map_location="cpu")
        checkpoint_model = checkpoint["model"]
        state_dict = self.model.state_dict()
        for k in ["head.weight", "head.bias"]:
            if (
                k in checkpoint_model
                and checkpoint_model[k].shape != state_dict[k].shape
            ):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Resize positional embeddings to the configured input size
        interpolate_pos_embed(self.model, checkpoint_model)
        self.model.load_state_dict(checkpoint_model, strict=False)

        # Fixed GSD value passed to the scale-aware positional encoding
        self.res = nn.Parameter(torch.ones(1).float() * gsd_ratio)

    def forward(self, x):
        return self.model(x, self.res)
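
For reference, I'm calling the wrapper roughly like this (illustrative shapes and values):

# Illustrative usage: images are Bx3x224x224, channel-wise standardized
model = ScaleMAEWrapper("scalemae-vitlarge-800.pth", gsd_ratio=1.0)
model.eval()
with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))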

Related, would you all be amenable to posting the pre-trained ViTs you have without the GSD encoder?

(@cjrd @RitwikGupta @jacklishufan , and @isaaccorley for visibility)

jacklishufan commented 1 year ago

Hi, we are not planning to release ViTs without the GSD encoder in the immediate future. In principle, you can always set self.res = 1.0 and use it as a GSD-free encoder. As for your question: as mentioned in the paper, the GSD ratio is the image's GSD relative to a reference GSD, not the absolute value in meters/pixel. In our multiscale experiments, we set res to gsd_ratio * (224 / eval_scale), where gsd_ratio = 1.0 by default. So the res passed into forward will be 1.0 at 100% resolution (224) and 2.0 at half resolution (112).

Generally, this mechanism is there to ensure that a sine wave covers the same spatial distance. For example, if the frequency is 1 cycle / 224 pixels, then after resizing the image to 112 we would have 1 cycle / 112 pixels, or 2 cycles / 224 pixels. If the GSD is 10 m/pixel at 224, it will be 20 m/pixel after downsampling the image to 112. However, the effective frequency in both cases is 1 cycle / 2240 meters.
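
To make that arithmetic concrete, here is a minimal sketch of the computation (the function name is illustrative, not from our codebase):

def compute_res(eval_scale: int, gsd_ratio: float = 1.0) -> float:
    # res = gsd_ratio * (224 / eval_scale), relative to the 224px reference
    return gsd_ratio * (224 / eval_scale)

print(compute_res(224))  # 1.0 -> full resolution
print(compute_res(112))  # 2.0 -> half resolution, i.e. 2x coarser GSD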

calebrob6 commented 1 year ago

Okay, so if I'm passing EuroSAT images that are 64x64 by default with a GSD of 10 m/px, but rescaled to 224x224 pixels, what should gsd_ratio (or res) be?

calebrob6 commented 1 year ago

By the way, here is a self-contained code sample for reproducing the numbers I mentioned: https://gist.github.com/calebrob6/912c2509de9d94ad6bc924420eca40bb.

I've tried a few different things from https://github.com/bair-climate-initiative/scale-mae/blob/main/mae/eval/knn.py#L39 but am not able to get the 96.0% top-1 accuracy.

jacklishufan commented 1 year ago

Hi. We reran the evaluation and were able to reproduce the result using the vanilla evaluation command:

python3 -m torch.distributed.launch --nproc_per_node=4  --master_port 33445 scripts/eval_launcher.py  --eval_config scripts/evalconf/dgx-conf.yaml --knn 5

where our config file has the format:

# Specify the root directory that contains your checkpoints
root: <root folder name>

exp_ids:
  - <subdirectory containing checkpoint-latest.pth>

evals:
  - id: eurosat
    path: <data root>
    scales:
      - 16
      - 32
      - 64

Logs

[05:02:26.845328] Starting KNN evaluation with K=5
0it [00:00, ?it/s][05:02:30.034698] Eval data mean (should be near 0): tensor(-0.0519)
[05:02:30.037960] Eval data std (should be near 1): tensor(0.9873)
85it [00:21,  4.03it/s]
[05:02:48.603433] distributed world size: 4
[05:02:48.607471] Grabbing all kNN training features took  21.8 seconds
[05:02:48.607525] Shape of final train features torch.Size([1024, 21760])
22it [00:04,  5.45it/s]
[05:02:53.280825] eval results (16): 0.7505555555555555
[05:02:53.282928] Starting KNN evaluation with K=5
0it [00:00, ?it/s][05:02:58.192556] Eval data mean (should be near 0): tensor(-0.0519)
[05:02:58.196521] Eval data std (should be near 1): tensor(0.9873)
85it [00:25,  3.32it/s]
[05:03:19.508151] distributed world size: 4
[05:03:19.512854] Grabbing all kNN training features took  26.2 seconds
[05:03:19.512914] Shape of final train features torch.Size([1024, 21760])
22it [00:02,  7.78it/s]
[05:03:23.011567] eval results (32): 0.9120370370370371
[05:03:23.014749] Starting KNN evaluation with K=5
0it [00:00, ?it/s][05:03:27.045041] Eval data mean (should be near 0): tensor(-0.0519)
[05:03:27.051554] Eval data std (should be near 1): tensor(0.9873)
85it [00:18,  4.62it/s]
[05:03:42.034873] distributed world size: 4
[05:03:42.037856] Grabbing all kNN training features took  19.0 seconds
[05:03:42.037906] Shape of final train features torch.Size([1024, 21760])
22it [00:05,  3.87it/s]
[05:03:48.729527] eval results (64): 0.9601851851851851
[05:03:48.729616] Training time 0:01:21

Hope this info helps.

Also, our checkpoint has the checksum 00acf88bd23ac9f02ac7e073b231e71705c63be5, generated via sha1sum; can you check this as well? There is a slim chance that we uploaded the wrong weights.

For your specific question, res should be 224 / 64 = 3.5.

RitwikGupta commented 1 year ago

The sha1sum for the released weights is also 00acf88bd23ac9f02ac7e073b231e71705c63be5, so the weights are correct.

isaaccorley commented 1 year ago

@jacklishufan Could you upload the EuroSAT train/val/test splits you are using and explain how they were generated? They don't appear to exist in the mae/splits directory.

calebrob6 commented 1 year ago

I think I see a discrepancy: the shape of the final train features is torch.Size([1024, 21760]), whereas the train/val/test set sizes in TorchGeo are 16200/5400/5400 (the same as in https://arxiv.org/abs/1911.06721).

calebrob6 commented 1 year ago

So combining the train and val sets (which gives me 21600 samples, not 21760) and setting res=1 gives me an accuracy of 0.9565, which is closer. However, if I set res = 224/64 then I get a much worse result of 0.9183.

jacklishufan commented 1 year ago

Hi, we pushed the EuroSAT split to our codebase. Hope this helps reproduce the result.

We also checked our gsd_ratio. The actual value passed as gsd_ratio is 14, 7, or 3.5 in our multiscale evaluation, with 3.5 used for the 64x64 images. This should give you the 0.9601 kNN result.

Example of calling the model:

features = net(
    inputs,
    input_res=torch.ones(len(inputs)).float().to(inputs.device) * gsd_ratio,
    knn_feats=True,
)

Also, we use the model from models_mae for kNN evaluation, while models_vit is for linear eval/fine-tuning. In principle there should be no difference, but you may want to switch to models_mae as a step to isolate the problem.
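
A rough end-to-end sketch of that path (the factory name mae_vit_large_patch16 is an assumption; check src/models_mae.py for the actual constructor):

import torch
from src.models_mae import mae_vit_large_patch16  # assumed factory name

# Build the MAE model and load the released weights (non-strict, in case
# decoder keys are absent)
net = mae_vit_large_patch16()
ckpt = torch.load("scalemae-vitlarge-800.pth", map_location="cpu")
net.load_state_dict(ckpt["model"], strict=False)
net.eval()

# kNN features for 64x64 EuroSAT crops upsampled to 224 -> gsd_ratio = 224/64
inputs = torch.randn(8, 3, 224, 224)  # placeholder batch
gsd_ratio = 3.5
with torch.no_grad():
    features = net(
        inputs,
        input_res=torch.ones(len(inputs)).float() * gsd_ratio,
        knn_feats=True,
    )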

calebrob6 commented 1 year ago

I'm happy to report I can get 0.9605 with the models_mae model implementation, using the following for inference:

import torch
import torch.nn.functional as F

# images is a Bx3x64x64 batch that has been channel-wise standardized using train statistics
images = F.interpolate(images.float(), (224, 224), mode="bilinear")
with torch.inference_mode():
    features = model.forward(images, knn_feats=True, input_res=torch.ones(1)).cpu().numpy()
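
For completeness, the KNN-5 probe over these features can then be a plain scikit-learn classifier; a minimal sketch, assuming train/test features and labels have been extracted as above (the repo's eval code may use a different distance metric):

from sklearn.neighbors import KNeighborsClassifier

# 5-nearest-neighbor probe on the frozen Scale-MAE features
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_features, train_labels)
print("KNN-5 test accuracy:", knn.score(test_features, test_labels))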

I'd love to include this in torchgeo; however, models_mae depends on the very old timm==0.3.2 and an early version of numpy. Any chance you all would be able to refactor it to remove these dependencies / conform to a standard API? E.g., I'd love to be able to do something like:

model = scalemae_vit_large_patch16_encoder(input_size=224, res=1.0)
model.load_state_dict(torch.load("your_weights.chpt"))

model.forward_features(torch.randn(1,3,224,224))

(p.s. thanks a ton for your help in this @jacklishufan!)

RitwikGupta commented 1 year ago

@calebrob6 we just merged #2, which removes the dependency on timm==0.3.2 and fixes the code to work with newer versions of numpy. This should now be ready to integrate into TorchGeo.

CarlosGomes98 commented 1 month ago

@calebrob6 if I could just confirm: in the end you achieved the best accuracy by setting res to 1, despite resizing the images through interpolation by ~4x?

Even after reading this thread, it's not at all clear to me how the res value should be set. I thought it would somehow be related to the GSD of the original image.

But from what I understood from this thread, if I have one image captured at 10 m GSD and another at 30 m, both 64x64 pixels, and I resize them both to 224x224, then both should use a res value of 3.5? I find it hard to understand the value of res in this case...