facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License
3.24k stars 330 forks

Cannot reproduce colorization results #518

Open · YannDubs opened this issue 2 years ago

YannDubs commented 2 years ago

Hi,

I'm trying to evaluate many of the pretrained models available in VISSL (thanks for that!!!).

I was able to reproduce all the ones I tried (rotnet, jigsaw, dino, pirl, npid, clusterfit, simclr, ...) except colorization, and I wonder what I am doing wrong.

Here's the code I'm using:

from torch.hub import load_state_dict_from_url
from torchvision import transforms

from vissl.config import AttrDict
from vissl.models.trunks.resnext import ResNeXt
# ImgPil2LabTensor is VISSL's PIL -> Lab tensor transform (module path may vary across versions)
from vissl.data.ssl_transforms.img_pil_to_lab_tensor import ImgPil2LabTensor

# default parameters while avoiding configs
dflt_rn_cfg = AttrDict({"INPUT_TYPE": "rgb",
                        "ACTIVATION_CHECKPOINTING": {"USE_ACTIVATION_CHECKPOINTING": False,
                                                     "NUM_ACTIVATION_CHECKPOINTING_SPLITS": 2},
                        "TRUNK": {"RESNETS": {"DEPTH": 50, "WIDTH_MULTIPLIER": 1, "NORM": "BatchNorm",
                                              "GROUPNORM_GROUPS": 32, "STANDARDIZE_CONVOLUTIONS": False,
                                              "GROUPS": 1, "ZERO_INIT_RESIDUAL": False,
                                              "WIDTH_PER_GROUP": 64, "LAYER4_STRIDE": 2}}})

# changes for colorization
dflt_rn_cfg.INPUT_TYPE = "lab"
dflt_rn_cfg.TRUNK.RESNETS.LAYER4_STRIDE = 1

encoder = ResNeXt(dflt_rn_cfg, "resnet")
encoder.feat_eval_mapping = None # only return the last representation

# load model
URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/converted_vissl_rn50_colorization_in1k_goyal19.torch"
state_dict = load_state_dict_from_url(url=URL, map_location="cpu")
state_dict = state_dict["model_state_dict"]
encoder.load_state_dict(state_dict, strict=False)

# transform for colorization
preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                ImgPil2LabTensor()
            ])

I then use preprocess to transform the data, encoder to featurize the input, and fit a linear classifier on ImageNet. This code has worked for the dozen models I have tried, but for colorization I get around 14% linear-probe accuracy, which is much lower than reported. Given that colorization requires many changes (different layer-4 stride, input type, and transform), I wonder if I forgot some other necessary change?
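For concreteness, here is a minimal sketch of the featurization step (img is an illustrative PIL image; in my setup the encoder returns a list with a single feature tensor, which I flatten before fitting the linear classifier):

import torch

encoder.eval()
with torch.no_grad():
    x = preprocess(img).unsqueeze(0)    # add a batch dimension
    feats = encoder(x)[-1].flatten(1)   # last (and only) returned representation
# feats and the ImageNet labels are then used to fit the linear classifier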

Note that using strict=True does give the error Unexpected key(s) in state_dict: "_feature_blocks.data.ab.ss.bias", "_feature_blocks.data.ab.ss.weight", "_feature_blocks.fc1.bias", "_feature_blocks.fc1.weight", but I believe that is expected? If I understand correctly, fc1 is the head used only at training time, but I'm not sure about _feature_blocks.data.ab.ss.weight, since both during training and inference we only use the L channel.
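In case it is useful, here is a rough sketch of how I drop those keys and double-check the alignment (load_state_dict with strict=False returns the missing/unexpected keys):

# filter out the training head (fc1) and the two data.ab.ss entries before loading
trunk_state = {k: v for k, v in state_dict.items()
               if not k.startswith(("_feature_blocks.fc1", "_feature_blocks.data.ab.ss"))}
incompatible = encoder.load_state_dict(trunk_state, strict=False)
print(incompatible.missing_keys, incompatible.unexpected_keys)  # both should be empty if only those four keys differ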

Sorry for the vague question, but I have been looking for hours, and thought that you might just directly know what is missing.

Thank you for your help and library!!

QuentinDuval commented 2 years ago

Hi @YannDubs,

First of all, thank you for using VISSL :)

From what you are reporting, it does seem that the weights are not aligned with the model being evaluated, so that is where we should look.

I spun up a quick Jupyter notebook and downloaded the weights at the URL you mentioned: https://dl.fbaipublicfiles.com/vissl/model_zoo/converted_vissl_rn50_rotnet_16kclusters_in1k_ep105.torch

And then I loaded the checkpoint itself:

import torch

cp_color = torch.load("/path/to/converted_vissl_rn50_rotnet_16kclusters_in1k_ep105.torch")
count = 0
for k, v in cp_color["classy_state_dict"]["base_model"]["model"]["trunk"].items():
    count += v.numel()
    if "ab.ss" in k:
        print(k)
print(count)

This outputs a number of parameters consistent with a RN50, and I cannot see any "_feature_blocks.data.ab.ss.weight" in the downloaded checkpoint, so there is something to investigate there. The other point is that the format of a VISSL checkpoint is not quite the same as the format expected by nn.Module, so the issue might also be there.
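For reference, when a checkpoint is in that VISSL/Classy layout, the trunk weights sit under the nested key used above and their names ("_feature_blocks.*") already match the trunk module, so a quick sketch like the following should load them into a trunk built as in your snippet (this is just a quick check, not the official conversion path):

# sketch: pull the trunk weights out of a classy_state_dict-style checkpoint
trunk_state = cp_color["classy_state_dict"]["base_model"]["model"]["trunk"]
encoder.load_state_dict(trunk_state, strict=False)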

Could you please tell me what the keys of your checkpoint look like? I am mostly interested in the hierarchy of the dictionary, for instance:

checkpoint = {
   "classy_state_dict": {
      "base_model": {
          # etc
      }
   }
}
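If it helps, here is a quick sketch to dump that hierarchy from a downloaded checkpoint (replace the path with yours):

import torch

checkpoint = torch.load("/path/to/checkpoint.torch", map_location="cpu")

def show_keys(node, indent=0, max_depth=3):
    # recursively print the keys of nested dictionaries down to max_depth levels
    if not isinstance(node, dict) or indent // 2 >= max_depth:
        return
    for key, value in node.items():
        print(" " * indent + str(key))
        show_keys(value, indent + 2, max_depth)

show_keys(checkpoint)
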
YannDubs commented 2 years ago

Hi @QuentinDuval

Sorry, I copy-pasted the wrong link; I meant colorization (rotnet works as expected): https://dl.fbaipublicfiles.com/vissl/model_zoo/converted_vissl_rn50_colorization_in1k_goyal19.torch

The keys of the checkpoint are the following:

 state_dict = {
      "model_state_dict" : {
      '_feature_blocks.layer3.4.bn1.bias', '_feature_blocks.layer2.1.bn2.running_mean', '_feature_blocks.layer1.1.conv3.weight', '_feature_blocks.layer4.2.bn3.weight', '_feature_blocks.layer4.0.bn1.running_mean', '_feature_blocks.layer2.1.bn1.bias', '_feature_blocks.layer3.4.conv1.weight', '_feature_blocks.layer1.2.conv2.weight', '_feature_blocks.layer2.0.downsample.1.weight', '_feature_blocks.layer3.0.bn1.bias', '_feature_blocks.layer3.3.bn2.weight', '_feature_blocks.layer3.3.bn1.running_var', '_feature_blocks.layer1.1.conv2.weight', '_feature_blocks.layer1.0.bn3.bias', '_feature_blocks.layer3.5.bn2.running_var', '_feature_blocks.layer4.1.bn3.bias', '_feature_blocks.layer3.2.bn1.running_mean', '_feature_blocks.layer3.0.bn3.bias', '_feature_blocks.layer3.2.conv2.weight', '_feature_blocks.layer1.2.bn2.weight', '_feature_blocks.layer4.0.downsample.1.running_var', '_feature_blocks.layer2.1.bn1.weight', '_feature_blocks.layer1.0.bn1.running_mean', '_feature_blocks.layer2.3.bn3.weight', '_feature_blocks.layer3.0.bn3.weight', '_feature_blocks.layer2.2.bn1.running_mean', '_feature_blocks.layer1.0.bn3.weight', '_feature_blocks.layer2.2.bn3.running_mean', '_feature_blocks.fc1.weight', '_feature_blocks.layer2.3.bn2.running_var', '_feature_blocks.layer1.2.bn1.running_var', '_feature_blocks.layer1.0.conv3.weight', '_feature_blocks.layer1.2.bn3.bias', '_feature_blocks.layer2.2.bn1.running_var', '_feature_blocks.layer2.0.downsample.1.running_var', '_feature_blocks.layer3.5.bn3.running_var', '_feature_blocks.layer3.3.bn3.weight', '_feature_blocks.layer3.2.bn2.running_mean', '_feature_blocks.layer1.1.bn2.running_var', '_feature_blocks.layer2.1.bn2.running_var', '_feature_blocks.layer3.5.bn3.running_mean', '_feature_blocks.layer4.1.bn1.weight', '_feature_blocks.layer3.3.bn3.bias', '_feature_blocks.layer4.2.bn2.running_var', '_feature_blocks.layer3.4.bn2.running_var', '_feature_blocks.layer1.2.bn2.bias', '_feature_blocks.layer1.0.downsample.1.weight', '_feature_blocks.layer3.1.bn1.bias', '_feature_blocks.layer1.2.bn1.running_mean', '_feature_blocks.layer3.0.bn1.weight', '_feature_blocks.layer2.0.bn3.running_mean', '_feature_blocks.layer2.1.conv2.weight', '_feature_blocks.layer4.0.bn1.bias', '_feature_blocks.layer4.2.bn2.weight', '_feature_blocks.layer1.0.bn2.weight', '_feature_blocks.layer1.2.bn3.running_var', '_feature_blocks.layer2.0.conv3.weight', '_feature_blocks.layer3.5.conv3.weight', '_feature_blocks.layer2.1.bn3.running_mean', '_feature_blocks.layer2.0.bn2.bias', '_feature_blocks.layer4.1.bn2.running_var', '_feature_blocks.layer2.3.bn1.weight', '_feature_blocks.layer1.2.bn2.running_var', '_feature_blocks.layer3.4.bn2.weight', '_feature_blocks.layer2.2.conv3.weight', '_feature_blocks.layer3.0.bn3.running_mean', '_feature_blocks.layer3.1.conv1.weight', '_feature_blocks.layer3.3.bn2.bias', '_feature_blocks.layer4.0.bn3.running_var', '_feature_blocks.layer3.5.bn2.bias', '_feature_blocks.data.ab.ss.bias', '_feature_blocks.layer1.2.conv3.weight', '_feature_blocks.layer4.1.bn2.running_mean', '_feature_blocks.layer3.3.conv1.weight', '_feature_blocks.layer3.4.bn1.running_mean', '_feature_blocks.layer4.0.downsample.1.bias', '_feature_blocks.layer3.5.bn1.running_mean', '_feature_blocks.layer2.2.bn2.weight', '_feature_blocks.layer3.3.bn3.running_var', '_feature_blocks.layer3.2.bn3.weight', '_feature_blocks.layer3.5.bn3.weight', '_feature_blocks.layer4.2.bn1.weight', '_feature_blocks.layer2.0.bn2.running_var', '_feature_blocks.layer3.4.bn3.running_var', '_feature_blocks.layer2.3.bn2.bias', 
'_feature_blocks.layer1.1.bn2.running_mean', '_feature_blocks.layer3.4.bn3.bias', '_feature_blocks.layer4.2.bn1.running_mean', '_feature_blocks.layer2.3.conv3.weight', '_feature_blocks.layer2.3.bn1.bias', '_feature_blocks.fc1.bias', '_feature_blocks.layer4.1.bn2.weight', '_feature_blocks.layer1.1.bn3.running_mean', '_feature_blocks.layer3.0.bn1.running_var', '_feature_blocks.layer4.2.conv2.weight', '_feature_blocks.layer2.2.bn1.bias', '_feature_blocks.layer2.1.bn3.running_var', '_feature_blocks.layer3.0.bn2.weight', '_feature_blocks.layer4.2.bn1.running_var', '_feature_blocks.layer4.0.bn3.bias', '_feature_blocks.layer3.4.conv3.weight', '_feature_blocks.layer1.2.bn3.running_mean', '_feature_blocks.layer3.2.bn3.running_var', '_feature_blocks.layer1.0.downsample.1.running_mean', '_feature_blocks.layer2.3.bn2.weight', '_feature_blocks.layer2.3.bn3.bias', '_feature_blocks.layer2.1.bn1.running_mean', '_feature_blocks.layer3.1.bn3.weight', '_feature_blocks.layer3.2.conv1.weight', '_feature_blocks.layer3.4.bn2.running_mean', '_feature_blocks.layer1.1.conv1.weight', '_feature_blocks.layer3.1.conv3.weight', '_feature_blocks.layer1.0.bn2.bias', '_feature_blocks.layer4.0.conv1.weight', '_feature_blocks.layer1.0.bn2.running_mean', '_feature_blocks.layer2.2.bn3.bias', '_feature_blocks.layer3.1.conv2.weight', '_feature_blocks.layer4.1.conv3.weight', '_feature_blocks.layer2.0.downsample.0.weight', '_feature_blocks.layer2.1.conv3.weight', '_feature_blocks.layer1.0.bn2.running_var', '_feature_blocks.layer2.0.downsample.1.bias', '_feature_blocks.layer2.0.bn1.running_var', '_feature_blocks.layer2.0.bn1.bias', '_feature_blocks.layer2.1.bn3.bias', '_feature_blocks.layer2.0.bn3.bias', '_feature_blocks.layer3.4.bn1.running_var', '_feature_blocks.layer2.0.bn2.weight', '_feature_blocks.layer3.5.bn1.running_var', '_feature_blocks.layer3.0.bn2.running_mean', '_feature_blocks.layer3.0.conv3.weight', '_feature_blocks.layer1.1.bn2.bias', '_feature_blocks.layer2.1.bn2.bias', '_feature_blocks.layer2.3.bn2.running_mean', '_feature_blocks.layer1.2.bn3.weight', '_feature_blocks.data.ab.ss.weight', '_feature_blocks.layer2.1.bn2.weight', '_feature_blocks.layer1.1.bn1.running_mean', '_feature_blocks.layer3.3.bn3.running_mean', '_feature_blocks.layer2.1.bn1.running_var', '_feature_blocks.layer4.1.conv1.weight', '_feature_blocks.bn1.weight', '_feature_blocks.layer3.3.conv2.weight', '_feature_blocks.layer3.4.bn2.bias', '_feature_blocks.layer4.2.conv1.weight', '_feature_blocks.layer3.0.bn2.running_var', '_feature_blocks.layer3.1.bn3.running_mean', '_feature_blocks.layer2.0.bn3.weight', '_feature_blocks.layer4.0.bn1.weight', '_feature_blocks.layer4.0.bn3.weight', '_feature_blocks.layer1.0.downsample.1.bias', '_feature_blocks.layer1.2.bn1.weight', '_feature_blocks.layer2.0.conv2.weight', '_feature_blocks.layer2.0.bn3.running_var', '_feature_blocks.layer4.1.bn1.bias', '_feature_blocks.layer3.3.bn1.bias', '_feature_blocks.layer4.0.bn2.running_var', '_feature_blocks.layer4.0.bn2.bias', '_feature_blocks.layer2.3.bn3.running_mean', '_feature_blocks.layer3.5.bn2.weight', '_feature_blocks.layer4.2.bn1.bias', '_feature_blocks.layer3.0.downsample.0.weight', '_feature_blocks.layer2.2.bn1.weight', '_feature_blocks.layer1.0.downsample.0.weight', '_feature_blocks.layer4.1.bn1.running_var', '_feature_blocks.layer3.3.bn2.running_var', '_feature_blocks.layer1.0.bn1.weight', '_feature_blocks.layer3.5.bn3.bias', '_feature_blocks.layer4.2.bn3.running_var', '_feature_blocks.layer3.3.bn1.running_mean', '_feature_blocks.layer1.0.bn3.running_var', 
'_feature_blocks.layer4.1.conv2.weight', '_feature_blocks.layer3.0.conv2.weight', '_feature_blocks.layer3.1.bn1.running_mean', '_feature_blocks.layer2.0.bn1.weight', '_feature_blocks.layer3.1.bn2.weight', '_feature_blocks.layer4.2.bn2.bias', '_feature_blocks.layer3.2.bn1.running_var', '_feature_blocks.layer2.3.bn1.running_var', '_feature_blocks.layer3.2.bn2.running_var', '_feature_blocks.layer3.2.bn3.running_mean', '_feature_blocks.layer3.0.downsample.1.running_mean', '_feature_blocks.layer1.0.bn1.running_var', '_feature_blocks.layer3.1.bn2.running_mean', '_feature_blocks.layer1.2.conv1.weight', '_feature_blocks.layer3.2.bn2.weight', '_feature_blocks.layer4.0.bn2.weight', '_feature_blocks.layer2.3.conv2.weight', '_feature_blocks.layer2.1.conv1.weight', '_feature_blocks.layer4.1.bn1.running_mean', '_feature_blocks.layer4.2.bn3.running_mean', '_feature_blocks.layer3.1.bn1.running_var', '_feature_blocks.layer3.1.bn3.running_var', '_feature_blocks.layer1.1.bn1.running_var', '_feature_blocks.layer1.0.conv1.weight', '_feature_blocks.layer3.0.bn2.bias', '_feature_blocks.layer1.1.bn3.weight', '_feature_blocks.layer3.2.bn1.weight', '_feature_blocks.layer3.2.conv3.weight', '_feature_blocks.layer3.2.bn3.bias', '_feature_blocks.layer2.3.conv1.weight', '_feature_blocks.layer4.0.downsample.1.running_mean', '_feature_blocks.layer4.0.bn2.running_mean', '_feature_blocks.layer2.3.bn3.running_var', '_feature_blocks.layer3.3.conv3.weight', '_feature_blocks.layer1.2.bn1.bias', '_feature_blocks.bn1.running_mean', '_feature_blocks.layer2.2.conv2.weight', '_feature_blocks.layer1.0.bn1.bias', '_feature_blocks.layer2.2.conv1.weight', '_feature_blocks.layer2.0.downsample.1.running_mean', '_feature_blocks.layer2.3.bn1.running_mean', '_feature_blocks.layer1.1.bn1.bias', '_feature_blocks.layer1.0.conv2.weight', '_feature_blocks.layer1.1.bn1.weight', '_feature_blocks.layer3.1.bn2.running_var', '_feature_blocks.layer3.3.bn2.running_mean', '_feature_blocks.layer3.5.bn1.bias', '_feature_blocks.layer4.2.conv3.weight', '_feature_blocks.layer3.5.bn2.running_mean', '_feature_blocks.layer3.3.bn1.weight', '_feature_blocks.layer2.2.bn2.running_mean', '_feature_blocks.layer4.0.downsample.1.weight', '_feature_blocks.layer3.0.conv1.weight', '_feature_blocks.layer4.1.bn2.bias', '_feature_blocks.layer2.2.bn2.bias', '_feature_blocks.layer2.2.bn3.weight', '_feature_blocks.layer4.2.bn2.running_mean', '_feature_blocks.layer3.1.bn1.weight', '_feature_blocks.layer3.0.bn1.running_mean', '_feature_blocks.layer2.0.bn2.running_mean', '_feature_blocks.conv1.weight', '_feature_blocks.layer3.4.bn3.weight', '_feature_blocks.bn1.bias', '_feature_blocks.layer3.0.downsample.1.weight', '_feature_blocks.layer4.0.bn1.running_var', '_feature_blocks.layer3.0.downsample.1.bias', '_feature_blocks.layer2.1.bn3.weight', '_feature_blocks.layer2.2.bn3.running_var', '_feature_blocks.layer3.4.conv2.weight', '_feature_blocks.layer3.2.bn1.bias', '_feature_blocks.layer1.1.bn2.weight', '_feature_blocks.layer3.1.bn2.bias', '_feature_blocks.layer4.0.downsample.0.weight', '_feature_blocks.layer1.2.bn2.running_mean', '_feature_blocks.layer3.5.bn1.weight', '_feature_blocks.layer4.0.conv2.weight', '_feature_blocks.layer4.0.bn3.running_mean', '_feature_blocks.layer3.0.downsample.1.running_var', '_feature_blocks.layer4.2.bn3.bias', '_feature_blocks.layer3.4.bn3.running_mean', '_feature_blocks.layer4.1.bn3.running_mean', '_feature_blocks.layer1.0.downsample.1.running_var', '_feature_blocks.layer1.1.bn3.running_var', '_feature_blocks.layer1.0.bn3.running_mean', 
'_feature_blocks.layer3.0.bn3.running_var', '_feature_blocks.layer3.4.bn1.weight', '_feature_blocks.layer2.0.bn1.running_mean', '_feature_blocks.layer1.1.bn3.bias', '_feature_blocks.layer4.0.conv3.weight', '_feature_blocks.layer4.1.bn3.weight', '_feature_blocks.layer4.1.bn3.running_var', '_feature_blocks.layer2.2.bn2.running_var', '_feature_blocks.layer3.2.bn2.bias', '_feature_blocks.layer3.1.bn3.bias', '_feature_blocks.bn1.running_var', '_feature_blocks.layer3.5.conv2.weight', '_feature_blocks.layer2.0.conv1.weight', '_feature_blocks.layer3.5.conv1.weight'
      }
 }

The number of parameters is 46,632,061.

And the two strange keys I have are _feature_blocks.data.ab.ss.bias and _feature_blocks.data.ab.ss.weight, which I think come from the Caffe2 conversion (likely for the ab channels in the Lab colorization?), although I haven't been able to find them in the original code.
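For what it's worth, here is a quick sketch to inspect what those two entries actually hold (using the state_dict loaded in my first snippet):

# print the shapes of the two unexpected keys to see what they correspond to
for key in ("_feature_blocks.data.ab.ss.weight", "_feature_blocks.data.ab.ss.bias"):
    if key in state_dict:
        print(key, tuple(state_dict[key].shape))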

iseessel commented 2 years ago

@YannDubs We will have to look at this in more depth. We are very busy following a project deadline. Sorry for the delay.

YannDubs commented 2 years ago

Hi @QuentinDuval, just checking in on this. Is there any news on loading/evaluating the colorization models?