Open YannDubs opened 2 years ago
Hi @YannDubs,
First of all, thank you for using VISSL :)
From what you are reporting, it seems indeed that the weights are not aligned with the model to evaluate. So this is where we must be looking.
I popped up a quick Jupyter notebook, downloaded the weights at the URL you mentioned: https://dl.fbaipublicfiles.com/vissl/model_zoo/converted_vissl_rn50_rotnet_16kclusters_in1k_ep105.torch
And then I loaded the checkpoint itself:
import torch

cp_color = torch.load("/path/to/converted_vissl_rn50_rotnet_16kclusters_in1k_ep105.torch")
count = 0
for k, v in cp_color["classy_state_dict"]["base_model"]["model"]["trunk"].items():
    count += v.numel()  # accumulate the total number of parameters in the trunk
    if "ab.ss" in k:  # look for the suspicious colorization keys
        print(k)
print(count)
This outputs a number of parameters consistent with a RN50, and I cannot see any "_feature_blocks.data.ab.ss.weight" key in the downloaded checkpoint, so there is something to investigate there. The other point is that the format of a VISSL checkpoint is not quite the same as the format supported by nn.Module, so it might also be that.
Could you please tell me which keys the checkpoint you have contains? I am mostly interested in the hierarchy of the dictionary, for instance:
checkpoint = {
    "classy_state_dict": {
        "base_model": {
            # etc
        }
    }
}
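In case it helps, here is a minimal sketch of one way to print that hierarchy without dumping every tensor. The helper name `summarize_keys` and its `max_depth` / `max_children` parameters are made up for illustration and are not part of VISSL; it only assumes the checkpoint is a nested dictionary whose leaves may expose a `shape` attribute.

```python
from typing import Any, List, Mapping


def summarize_keys(
    node: Any, max_depth: int = 3, max_children: int = 5, _depth: int = 0
) -> List[str]:
    """Return indented lines describing the nested keys of a checkpoint-like dict.

    Hypothetical helper for illustration, not a VISSL function.
    """
    lines: List[str] = []
    # Stop at non-dict values or once the requested depth is reached.
    if not isinstance(node, Mapping) or _depth >= max_depth:
        return lines
    indent = "  " * _depth
    items = list(node.items())
    for key, value in items[:max_children]:
        if isinstance(value, Mapping):
            lines.append(f"{indent}{key}: dict with {len(value)} keys")
            lines.extend(summarize_keys(value, max_depth, max_children, _depth + 1))
        else:
            # Tensors and arrays expose .shape; fall back to the type name otherwise.
            shape = getattr(value, "shape", None)
            desc = f"shape {tuple(shape)}" if shape is not None else type(value).__name__
            lines.append(f"{indent}{key}: {desc}")
    # Only the first max_children entries are shown at each level.
    if len(items) > max_children:
        lines.append(f"{indent}... ({len(items) - max_children} more keys)")
    return lines
```

With PyTorch installed, something like `print("\n".join(summarize_keys(torch.load(path, map_location="cpu"))))` should give a compact view that is easy to paste into the issue.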
Hi @QuentinDuval
Sorry I copy-pasted the wrong link, I meant colorization (rotnet works as expected): https://dl.fbaipublicfiles.com/vissl/model_zoo/converted_vissl_rn50_colorization_in1k_goyal19.torch
The keys of the checkpoint are the following:
state_dict = {
"model_state_dict" : {
'_feature_blocks.layer3.4.bn1.bias', '_feature_blocks.layer2.1.bn2.running_mean', '_feature_blocks.layer1.1.conv3.weight', '_feature_blocks.layer4.2.bn3.weight', '_feature_blocks.layer4.0.bn1.running_mean', '_feature_blocks.layer2.1.bn1.bias', '_feature_blocks.layer3.4.conv1.weight', '_feature_blocks.layer1.2.conv2.weight', '_feature_blocks.layer2.0.downsample.1.weight', '_feature_blocks.layer3.0.bn1.bias', '_feature_blocks.layer3.3.bn2.weight', '_feature_blocks.layer3.3.bn1.running_var', '_feature_blocks.layer1.1.conv2.weight', '_feature_blocks.layer1.0.bn3.bias', '_feature_blocks.layer3.5.bn2.running_var', '_feature_blocks.layer4.1.bn3.bias', '_feature_blocks.layer3.2.bn1.running_mean', '_feature_blocks.layer3.0.bn3.bias', '_feature_blocks.layer3.2.conv2.weight', '_feature_blocks.layer1.2.bn2.weight', '_feature_blocks.layer4.0.downsample.1.running_var', '_feature_blocks.layer2.1.bn1.weight', '_feature_blocks.layer1.0.bn1.running_mean', '_feature_blocks.layer2.3.bn3.weight', '_feature_blocks.layer3.0.bn3.weight', '_feature_blocks.layer2.2.bn1.running_mean', '_feature_blocks.layer1.0.bn3.weight', '_feature_blocks.layer2.2.bn3.running_mean', '_feature_blocks.fc1.weight', '_feature_blocks.layer2.3.bn2.running_var', '_feature_blocks.layer1.2.bn1.running_var', '_feature_blocks.layer1.0.conv3.weight', '_feature_blocks.layer1.2.bn3.bias', '_feature_blocks.layer2.2.bn1.running_var', '_feature_blocks.layer2.0.downsample.1.running_var', '_feature_blocks.layer3.5.bn3.running_var', '_feature_blocks.layer3.3.bn3.weight', '_feature_blocks.layer3.2.bn2.running_mean', '_feature_blocks.layer1.1.bn2.running_var', '_feature_blocks.layer2.1.bn2.running_var', '_feature_blocks.layer3.5.bn3.running_mean', '_feature_blocks.layer4.1.bn1.weight', '_feature_blocks.layer3.3.bn3.bias', '_feature_blocks.layer4.2.bn2.running_var', '_feature_blocks.layer3.4.bn2.running_var', '_feature_blocks.layer1.2.bn2.bias', '_feature_blocks.layer1.0.downsample.1.weight', '_feature_blocks.layer3.1.bn1.bias', 
'_feature_blocks.layer1.2.bn1.running_mean', '_feature_blocks.layer3.0.bn1.weight', '_feature_blocks.layer2.0.bn3.running_mean', '_feature_blocks.layer2.1.conv2.weight', '_feature_blocks.layer4.0.bn1.bias', '_feature_blocks.layer4.2.bn2.weight', '_feature_blocks.layer1.0.bn2.weight', '_feature_blocks.layer1.2.bn3.running_var', '_feature_blocks.layer2.0.conv3.weight', '_feature_blocks.layer3.5.conv3.weight', '_feature_blocks.layer2.1.bn3.running_mean', '_feature_blocks.layer2.0.bn2.bias', '_feature_blocks.layer4.1.bn2.running_var', '_feature_blocks.layer2.3.bn1.weight', '_feature_blocks.layer1.2.bn2.running_var', '_feature_blocks.layer3.4.bn2.weight', '_feature_blocks.layer2.2.conv3.weight', '_feature_blocks.layer3.0.bn3.running_mean', '_feature_blocks.layer3.1.conv1.weight', '_feature_blocks.layer3.3.bn2.bias', '_feature_blocks.layer4.0.bn3.running_var', '_feature_blocks.layer3.5.bn2.bias', '_feature_blocks.data.ab.ss.bias', '_feature_blocks.layer1.2.conv3.weight', '_feature_blocks.layer4.1.bn2.running_mean', '_feature_blocks.layer3.3.conv1.weight', '_feature_blocks.layer3.4.bn1.running_mean', '_feature_blocks.layer4.0.downsample.1.bias', '_feature_blocks.layer3.5.bn1.running_mean', '_feature_blocks.layer2.2.bn2.weight', '_feature_blocks.layer3.3.bn3.running_var', '_feature_blocks.layer3.2.bn3.weight', '_feature_blocks.layer3.5.bn3.weight', '_feature_blocks.layer4.2.bn1.weight', '_feature_blocks.layer2.0.bn2.running_var', '_feature_blocks.layer3.4.bn3.running_var', '_feature_blocks.layer2.3.bn2.bias', '_feature_blocks.layer1.1.bn2.running_mean', '_feature_blocks.layer3.4.bn3.bias', '_feature_blocks.layer4.2.bn1.running_mean', '_feature_blocks.layer2.3.conv3.weight', '_feature_blocks.layer2.3.bn1.bias', '_feature_blocks.fc1.bias', '_feature_blocks.layer4.1.bn2.weight', '_feature_blocks.layer1.1.bn3.running_mean', '_feature_blocks.layer3.0.bn1.running_var', '_feature_blocks.layer4.2.conv2.weight', '_feature_blocks.layer2.2.bn1.bias', 
'_feature_blocks.layer2.1.bn3.running_var', '_feature_blocks.layer3.0.bn2.weight', '_feature_blocks.layer4.2.bn1.running_var', '_feature_blocks.layer4.0.bn3.bias', '_feature_blocks.layer3.4.conv3.weight', '_feature_blocks.layer1.2.bn3.running_mean', '_feature_blocks.layer3.2.bn3.running_var', '_feature_blocks.layer1.0.downsample.1.running_mean', '_feature_blocks.layer2.3.bn2.weight', '_feature_blocks.layer2.3.bn3.bias', '_feature_blocks.layer2.1.bn1.running_mean', '_feature_blocks.layer3.1.bn3.weight', '_feature_blocks.layer3.2.conv1.weight', '_feature_blocks.layer3.4.bn2.running_mean', '_feature_blocks.layer1.1.conv1.weight', '_feature_blocks.layer3.1.conv3.weight', '_feature_blocks.layer1.0.bn2.bias', '_feature_blocks.layer4.0.conv1.weight', '_feature_blocks.layer1.0.bn2.running_mean', '_feature_blocks.layer2.2.bn3.bias', '_feature_blocks.layer3.1.conv2.weight', '_feature_blocks.layer4.1.conv3.weight', '_feature_blocks.layer2.0.downsample.0.weight', '_feature_blocks.layer2.1.conv3.weight', '_feature_blocks.layer1.0.bn2.running_var', '_feature_blocks.layer2.0.downsample.1.bias', '_feature_blocks.layer2.0.bn1.running_var', '_feature_blocks.layer2.0.bn1.bias', '_feature_blocks.layer2.1.bn3.bias', '_feature_blocks.layer2.0.bn3.bias', '_feature_blocks.layer3.4.bn1.running_var', '_feature_blocks.layer2.0.bn2.weight', '_feature_blocks.layer3.5.bn1.running_var', '_feature_blocks.layer3.0.bn2.running_mean', '_feature_blocks.layer3.0.conv3.weight', '_feature_blocks.layer1.1.bn2.bias', '_feature_blocks.layer2.1.bn2.bias', '_feature_blocks.layer2.3.bn2.running_mean', '_feature_blocks.layer1.2.bn3.weight', '_feature_blocks.data.ab.ss.weight', '_feature_blocks.layer2.1.bn2.weight', '_feature_blocks.layer1.1.bn1.running_mean', '_feature_blocks.layer3.3.bn3.running_mean', '_feature_blocks.layer2.1.bn1.running_var', '_feature_blocks.layer4.1.conv1.weight', '_feature_blocks.bn1.weight', '_feature_blocks.layer3.3.conv2.weight', '_feature_blocks.layer3.4.bn2.bias', 
'_feature_blocks.layer4.2.conv1.weight', '_feature_blocks.layer3.0.bn2.running_var', '_feature_blocks.layer3.1.bn3.running_mean', '_feature_blocks.layer2.0.bn3.weight', '_feature_blocks.layer4.0.bn1.weight', '_feature_blocks.layer4.0.bn3.weight', '_feature_blocks.layer1.0.downsample.1.bias', '_feature_blocks.layer1.2.bn1.weight', '_feature_blocks.layer2.0.conv2.weight', '_feature_blocks.layer2.0.bn3.running_var', '_feature_blocks.layer4.1.bn1.bias', '_feature_blocks.layer3.3.bn1.bias', '_feature_blocks.layer4.0.bn2.running_var', '_feature_blocks.layer4.0.bn2.bias', '_feature_blocks.layer2.3.bn3.running_mean', '_feature_blocks.layer3.5.bn2.weight', '_feature_blocks.layer4.2.bn1.bias', '_feature_blocks.layer3.0.downsample.0.weight', '_feature_blocks.layer2.2.bn1.weight', '_feature_blocks.layer1.0.downsample.0.weight', '_feature_blocks.layer4.1.bn1.running_var', '_feature_blocks.layer3.3.bn2.running_var', '_feature_blocks.layer1.0.bn1.weight', '_feature_blocks.layer3.5.bn3.bias', '_feature_blocks.layer4.2.bn3.running_var', '_feature_blocks.layer3.3.bn1.running_mean', '_feature_blocks.layer1.0.bn3.running_var', '_feature_blocks.layer4.1.conv2.weight', '_feature_blocks.layer3.0.conv2.weight', '_feature_blocks.layer3.1.bn1.running_mean', '_feature_blocks.layer2.0.bn1.weight', '_feature_blocks.layer3.1.bn2.weight', '_feature_blocks.layer4.2.bn2.bias', '_feature_blocks.layer3.2.bn1.running_var', '_feature_blocks.layer2.3.bn1.running_var', '_feature_blocks.layer3.2.bn2.running_var', '_feature_blocks.layer3.2.bn3.running_mean', '_feature_blocks.layer3.0.downsample.1.running_mean', '_feature_blocks.layer1.0.bn1.running_var', '_feature_blocks.layer3.1.bn2.running_mean', '_feature_blocks.layer1.2.conv1.weight', '_feature_blocks.layer3.2.bn2.weight', '_feature_blocks.layer4.0.bn2.weight', '_feature_blocks.layer2.3.conv2.weight', '_feature_blocks.layer2.1.conv1.weight', '_feature_blocks.layer4.1.bn1.running_mean', '_feature_blocks.layer4.2.bn3.running_mean', 
'_feature_blocks.layer3.1.bn1.running_var', '_feature_blocks.layer3.1.bn3.running_var', '_feature_blocks.layer1.1.bn1.running_var', '_feature_blocks.layer1.0.conv1.weight', '_feature_blocks.layer3.0.bn2.bias', '_feature_blocks.layer1.1.bn3.weight', '_feature_blocks.layer3.2.bn1.weight', '_feature_blocks.layer3.2.conv3.weight', '_feature_blocks.layer3.2.bn3.bias', '_feature_blocks.layer2.3.conv1.weight', '_feature_blocks.layer4.0.downsample.1.running_mean', '_feature_blocks.layer4.0.bn2.running_mean', '_feature_blocks.layer2.3.bn3.running_var', '_feature_blocks.layer3.3.conv3.weight', '_feature_blocks.layer1.2.bn1.bias', '_feature_blocks.bn1.running_mean', '_feature_blocks.layer2.2.conv2.weight', '_feature_blocks.layer1.0.bn1.bias', '_feature_blocks.layer2.2.conv1.weight', '_feature_blocks.layer2.0.downsample.1.running_mean', '_feature_blocks.layer2.3.bn1.running_mean', '_feature_blocks.layer1.1.bn1.bias', '_feature_blocks.layer1.0.conv2.weight', '_feature_blocks.layer1.1.bn1.weight', '_feature_blocks.layer3.1.bn2.running_var', '_feature_blocks.layer3.3.bn2.running_mean', '_feature_blocks.layer3.5.bn1.bias', '_feature_blocks.layer4.2.conv3.weight', '_feature_blocks.layer3.5.bn2.running_mean', '_feature_blocks.layer3.3.bn1.weight', '_feature_blocks.layer2.2.bn2.running_mean', '_feature_blocks.layer4.0.downsample.1.weight', '_feature_blocks.layer3.0.conv1.weight', '_feature_blocks.layer4.1.bn2.bias', '_feature_blocks.layer2.2.bn2.bias', '_feature_blocks.layer2.2.bn3.weight', '_feature_blocks.layer4.2.bn2.running_mean', '_feature_blocks.layer3.1.bn1.weight', '_feature_blocks.layer3.0.bn1.running_mean', '_feature_blocks.layer2.0.bn2.running_mean', '_feature_blocks.conv1.weight', '_feature_blocks.layer3.4.bn3.weight', '_feature_blocks.bn1.bias', '_feature_blocks.layer3.0.downsample.1.weight', '_feature_blocks.layer4.0.bn1.running_var', '_feature_blocks.layer3.0.downsample.1.bias', '_feature_blocks.layer2.1.bn3.weight', '_feature_blocks.layer2.2.bn3.running_var', 
'_feature_blocks.layer3.4.conv2.weight', '_feature_blocks.layer3.2.bn1.bias', '_feature_blocks.layer1.1.bn2.weight', '_feature_blocks.layer3.1.bn2.bias', '_feature_blocks.layer4.0.downsample.0.weight', '_feature_blocks.layer1.2.bn2.running_mean', '_feature_blocks.layer3.5.bn1.weight', '_feature_blocks.layer4.0.conv2.weight', '_feature_blocks.layer4.0.bn3.running_mean', '_feature_blocks.layer3.0.downsample.1.running_var', '_feature_blocks.layer4.2.bn3.bias', '_feature_blocks.layer3.4.bn3.running_mean', '_feature_blocks.layer4.1.bn3.running_mean', '_feature_blocks.layer1.0.downsample.1.running_var', '_feature_blocks.layer1.1.bn3.running_var', '_feature_blocks.layer1.0.bn3.running_mean', '_feature_blocks.layer3.0.bn3.running_var', '_feature_blocks.layer3.4.bn1.weight', '_feature_blocks.layer2.0.bn1.running_mean', '_feature_blocks.layer1.1.bn3.bias', '_feature_blocks.layer4.0.conv3.weight', '_feature_blocks.layer4.1.bn3.weight', '_feature_blocks.layer4.1.bn3.running_var', '_feature_blocks.layer2.2.bn2.running_var', '_feature_blocks.layer3.2.bn2.bias', '_feature_blocks.layer3.1.bn3.bias', '_feature_blocks.bn1.running_var', '_feature_blocks.layer3.5.conv2.weight', '_feature_blocks.layer2.0.conv1.weight', '_feature_blocks.layer3.5.conv1.weight'
}
}
The number of parameters is 46 632 061
And the two strange keys I have are _feature_blocks.data.ab.ss.bias and _feature_blocks.data.ab.ss.weight, which I think come from the conversion from Caffe2 (likely for the ab channels in the Lab colorization?), although I haven't been able to find exactly where they come from in the original code.
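For anyone comparing against their own model, a small sketch of how the key mismatch can be listed explicitly. The function name `diff_keys` is hypothetical; it only compares key names and says nothing about whether the extra keys matter for evaluation.

```python
from typing import Dict, Iterable, List


def diff_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Dict[str, List[str]]:
    """Compare checkpoint key names against the key names a model expects.

    Hypothetical helper for illustration, not a VISSL or PyTorch function.
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return {
        # Present in the checkpoint but absent from the model.
        "unexpected": sorted(ckpt - model),
        # Expected by the model but absent from the checkpoint.
        "missing": sorted(model - ckpt),
    }
```

Typical usage would be `diff_keys(state_dict["model_state_dict"].keys(), model.state_dict().keys())`. PyTorch reports the same information itself: `model.load_state_dict(..., strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields.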
@YannDubs We will have to look at this more in-depth. We are very busy post-project deadline. Sorry for the delay.
Hi @QuentinDuval, just checking up on that. Is there any news for loading/evaluating the colorization models?
Hi,
I'm trying to evaluate many of the pretrained models available in VISSL (thanks for that!!!).
I was able to reproduce all the ones I tried (rotnet, jigsaw, dino, pirl, npid, clusterfit, simclr...) except colorization. I wonder what I am doing wrong. Here's the code I'm using
I then use preprocess to transform the data and encoder to featurize the input, and I fit a linear classifier on ImageNet. This code worked for the dozen or so models I have tried, but for colorization I get around 14% linear probing accuracy, which is much lower than reported. Given that colorization seems to be a model that requires many changes (different layer stride, input type, and transform), I wonder if I forgot any other necessary change?
Note that using strict=True does give the error:
Unexpected key(s) in state_dict: "_feature_blocks.data.ab.ss.bias", "_feature_blocks.data.ab.ss.weight", "_feature_blocks.fc1.bias", "_feature_blocks.fc1.weight".
but that seems to be expected, I believe? If I understand correctly, fc1 is the head that is only used at training time, but I'm not sure about _feature_blocks.data.ab.ss.weight, as during both training and inference we only use the L channel.
Sorry for the vague question, but I have been looking for hours, and thought that you might just directly know what is missing.
Thank you for your help and library!!