VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
Bug in the architecture of ResNet-50-w2 #469

Open CharlieCheckpt opened 2 years ago

CharlieCheckpt commented 2 years ago

Hi VISSL team! Thank you for the great package.

I got a dimension error when running the example from the documentation to train MoCo with ResNet-50-w2 (2x wider ResNet-50).

This error seems to be due to a bug in the architecture of ResNet-50-w2. Indeed I compared it with the architecture of torchvision.models.wide_resnet50_2 and the architectures are different.

Looking at step 3 below, one can see that in vissl the first layer of ResNet-50-w2 is: (conv1): Conv2d(3, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In torchvision:

from torchvision.models import wide_resnet50_2

# Build the model without pretrained weights and print its layers.
print(wide_resnet50_2())

prints:

(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) ...
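To make the difference concrete, here is a minimal pure-Python sketch of the channel counts under the two conventions. The torchvision-style numbers follow how its Bottleneck block derives the inner width from `width_per_group`. The "scale everything" function is only an assumption about a convention that multiplies every stage (including conv1) by the width factor, which is consistent with the `Conv2d(3, 128, ...)` printed above but is not copied from the vissl code. Both function names are hypothetical.

```python
def torchvision_style_widths(width_per_group=128, groups=1):
    """Only the inner convs of each bottleneck are widened; conv1 stays at 64."""
    expansion = 4  # bottleneck output = planes * 4
    stages = []
    for planes in (64, 128, 256, 512):
        inner = int(planes * (width_per_group / 64.0)) * groups
        stages.append({"inner": inner, "out": planes * expansion})
    return {"stem": 64, "stages": stages}


def scale_everything_widths(multiplier=2):
    """Assumed convention: conv1 and every stage are multiplied by the factor."""
    stem = 64 * multiplier
    stages = []
    for i in range(4):
        planes = stem * 2 ** i
        stages.append({"inner": planes, "out": planes * 4})
    return {"stem": stem, "stages": stages}


tv = torchvision_style_widths()
se = scale_everything_widths()
print("torchvision-style:", tv["stem"], [s["out"] for s in tv["stages"]])
print("scale-everything: ", se["stem"], [s["out"] for s in se["stages"]])
```

Under these assumptions the torchvision-style trunk still ends in 2048 features, while the fully scaled trunk ends in 4096, so a head sized for one would not fit the other.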

Instructions To Reproduce the 🐛 Bug:

  1. what changes you made (git diff) or what code you wrote None

  2. what exact command you run: I ran the command proposed in the documentation.

python tools/run_distributed_engines.py config=pretrain/moco/moco_1node_resnet \
    config.MODEL.TRUNK.NAME=resnet config.MODEL.TRUNK.RESNETS.DEPTH=50 \
  3. what you observed (including full logs):
INFO 2021-11-13 18:02:39,780 hydra_config.py: 132: Training with config:
 'TRAINER': {'TASK_NAME': 'self_supervision_task',
             'TRAIN_STEP_NAME': 'standard_train_step'},
 'VERBOSE': False}
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|    0   N/A  N/A     30064      C   python                           1279MiB |

INFO 2021-11-13 18:02:59,769 state_update_hooks.py: 115: Starting phase 0 [train]
Traceback (most recent call last):
  File "run_distributed_engines.py", line 54, in <module>
  File "run_distributed_engines.py", line 40, in hydra_main
  File "/home/xxx/vissl/vissl/utils/distributed_launcher.py", line 164, in launch_distributed
    raise e
  File "/home/xxx/vissl/vissl/utils/distributed_launcher.py", line 150, in launch_distributed
  File "/home/xxx/vissl/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
  File "/home/xxx/vissl/vissl/engines/engine_registry.py", line 86, in run_engine
  File "/home/xxx/vissl/vissl/engines/train.py", line 39, in run_engine
  File "/home/xxx/vissl/vissl/engines/train.py", line 130, in train_main
  File "/home/xxx/vissl/vissl/trainer/trainer_main.py", line 211, in train
    raise e
  File "/home/xxx/vissl/vissl/trainer/trainer_main.py", line 193, in train
    task = train_step_fn(task)
  File "/home/xxx/vissl/vissl/trainer/train_steps/standard_train_step.py", line 143, in standard_train_step
    model_output = task.model(sample["input"])
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 97, in __call__
    return self.forward(*args, **kwargs)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 111, in forward
    out = self.classy_model(*args, **kwargs)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xxx/vissl/vissl/models/base_ssl_model.py", line 180, in forward
    return self.single_input_forward(batch, self._output_feature_names, self.heads)
  File "/home/xxx/vissl/vissl/models/base_ssl_model.py", line 138, in single_input_forward
    return self.heads_forward(feats, heads)
  File "/home/xxx/vissl/vissl/models/base_ssl_model.py", line 159, in heads_forward
    output = head(output)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xxx/vissl/vissl/models/heads/mlp.py", line 111, in forward
    out = self.clf(batch)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/xxx/.conda/envs/vissl/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
  4. please simplify the steps as much as possible so they do not require additional resources to run, such as a private dataset.
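The RuntimeError at the end of the traceback above is raised inside a linear layer of the MLP head, which suggests that the feature size coming out of the trunk does not match the input size the head was built with. The sketch below mimics that check in plain Python. The sizes 4096 and 2048 are assumptions for illustration (the log does not print the actual shapes), and `linear_output_shape` is a hypothetical helper, not a vissl or PyTorch function.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Return the output shape of a linear layer, or raise on a size mismatch."""
    if input_shape[-1] != in_features:
        raise ValueError(
            f"input has {input_shape[-1]} features, layer expects {in_features}"
        )
    return (*input_shape[:-1], out_features)


# Assumed sizes: a head built for 2048 input features.
print(linear_output_shape((32, 2048), 2048, 128))  # matching sizes work

try:
    # Assumed: a 2x-wide trunk that emits 4096 features instead.
    linear_output_shape((32, 4096), 2048, 128)
except ValueError as err:
    print("mismatch:", err)
```

If this is what happens, the head's input dimension in the config would need to follow the trunk's output width.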

Expected behavior:

If there is no obvious error in "what you observed" provided above, please tell us the expected behavior.


Provide your environment information using the following command:

iseessel commented 2 years ago

@CharlieCheckpt Thank you for bringing this to our attention!

After some digging, it seems there may be differences between the wide architectures. The wide architecture introduced in https://arxiv.org/abs/1605.07146 indeed only doubles the width of the residual layers, not of conv1.

But I have seen official checkpoints that also double the width of the conv1 layer. See, for example, BYOL: https://github.com/chigur/byol-convert/blob/main/resnet.py#L169. This script properly loads and converts the ResNet-200 2x BYOL model.

I am not sure whether this is a confusion in the literature or a conscious decision; if it is deliberate, I have not seen it stated explicitly in what I've read.

@prigoyal @QuentinDuval Do you guys know anything more about this?

prigoyal commented 2 years ago

Agree with the above. It might be best to extend the resnext code in vissl to support different versions, e.g. "wide_resnet50_2" as in https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L20
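One way to picture that suggestion is a small table of named variants, each stating its stem width and bottleneck width explicitly. This is only a sketch: the `ResNetVariant` class, the `VARIANTS` table, and the name "resnet50_w2_scaled_stem" are hypothetical and do not exist in vissl. The `width_per_group=128` value for "wide_resnet50_2" matches what torchvision uses for that model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ResNetVariant:
    depth: int
    stem_channels: int    # output channels of conv1
    width_per_group: int  # base width of the bottleneck's inner convs
    groups: int = 1


# Hypothetical registry: each name pins down one width convention.
VARIANTS = {
    "resnet50": ResNetVariant(depth=50, stem_channels=64, width_per_group=64),
    "wide_resnet50_2": ResNetVariant(depth=50, stem_channels=64, width_per_group=128),
    "resnet50_w2_scaled_stem": ResNetVariant(depth=50, stem_channels=128, width_per_group=128),
}


def get_variant(name):
    """Look up a variant by name and fail with the list of known names."""
    try:
        return VARIANTS[name]
    except KeyError:
        known = ", ".join(sorted(VARIANTS))
        raise ValueError(f"unknown variant {name!r}; known: {known}") from None


print(get_variant("wide_resnet50_2"))
```

Naming the variants separately would let a checkpoint state which convention it was trained with, instead of relying on a single width multiplier.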