facebookresearch / suncet

Code to reproduce the results in the FAIR research papers "Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples" (https://arxiv.org/abs/2104.13963) and "Supervision Accelerates Pre-training in Contrastive Semi-Supervised Learning of Visual Representations" (https://arxiv.org/abs/2006.10803).
MIT License

error when loading the optimizer's state dict #36

Closed: amandalucasp closed this issue 2 years ago

amandalucasp commented 2 years ago

Hi! I'm trying to load the provided checkpoints in paws_train.py, in order to run a few extra pre-training epochs on my custom dataset. For the model, I had to remove "module" from the state_dict keys, but for the optimizer, I get the following error:

ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

Apparently, the number of parameters in the optimizer has changed. How can I solve this? Currently, I'm simply not loading the optimizer and using a new one created by init_opt.

Thanks in advance!

MidoAssran commented 2 years ago

Hi @amandalucasp, are you only removing "module" from the keys (like below), or are you deleting the keys with the "module" prefix altogether?

You should remove "module" from the keys as follows:

pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint['encoder'].items()}
amandalucasp commented 2 years ago

I'm doing exactly as you said, only removing the string "module" from the keys.

MidoAssran commented 2 years ago

ok I see; so for your custom optimizer, you just need to pass in two param groups instead of one (see here).

In short, one param group consists of all the bias and batch-norm params, and the other group consists of all the other params. The reason for creating two param groups is that we do not apply weight decay to the bias and batch-norm params during pre-training.
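
Roughly, the split looks like the sketch below (illustrative only, not the exact init_opt code; the helper name and the stand-in model are placeholders):

import torch
import torch.nn as nn

def make_param_groups(model, weight_decay=1e-06):
    # group 0: weights that receive weight decay
    # group 1: biases and batch-norm params, excluded from weight decay (and from LARS adaptation)
    regularized, unregularized = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 1 or name.endswith('.bias'):
            unregularized.append(param)
        else:
            regularized.append(param)
    return [
        {'params': regularized, 'weight_decay': weight_decay},
        {'params': unregularized, 'weight_decay': 0, 'LARS_exclude': True},
    ]

# stand-in model just to show the two groups being built
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
optimizer = torch.optim.SGD(make_param_groups(model), lr=0.3, momentum=0.9)
print(optimizer)  # prints two parameter groups, like the layout in the checkpoint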

Edit: Are you using a custom optimizer, or are you having this issue with the same LARS optimizer?

Edit: Another question: are you using the same model (i.e., from src/resnet.py), or are you using a custom implementation?

amandalucasp commented 2 years ago

Yeah, I'm using the same LARS optimizer as in the original code, and the same model; I didn't change anything in src/resnet.py.

MidoAssran commented 2 years ago

Ok do you know what line you're getting an error on? Is it here?

If so, could you log your optimizer and your optimizer state_dict in this thread to compare?
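
Something like this should be enough to dump both (just a sketch; the checkpoint path is a placeholder, and optimizer is the one you build via init_opt):

import torch
checkpoint = torch.load('path/to/checkpoint.pth.tar', map_location='cpu')
print(optimizer)                          # the optimizer built via init_opt
print(checkpoint['opt']['param_groups'])  # the param groups saved in the checkpoint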

amandalucasp commented 2 years ago

Yep! Exactly at that line.

Optimizer:

SGD (
Parameter Group 0
    initial_lr: 0.4
    lr: 0.3
    momentum: 0.9
    nesterov: False
    weight_decay: 1e-06

Parameter Group 1
    LARS_exclude: True
    initial_lr: 0.4
    lr: 0.3
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)

And for the state_dict:

# checkpoint['opt']['state'].keys()
dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176])
# checkpoint['opt']['param_groups']
[{'lr': 0.08391971905820779, 'momentum': 0.9, 'weight_decay': 1e-06, 'nesterov': False, 'initial_lr': 6.4, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}, {'LARS_exclude': True, 'weight_decay': 0, 'lr': 0.08391971905820779, 'momentum': 0.9, 'nesterov': False, 'initial_lr': 6.4, 'params': [62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176]}]
MidoAssran commented 2 years ago

Ok I checked briefly and it seems to be working fine for me. I'm trying to dig more into it right now, but which checkpoint is giving you the error?
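
In the meantime, a quick diagnostic you could run (a rough sketch, assuming checkpoint is the loaded .tar and optimizer is the one from init_opt):

# load_state_dict raises exactly that ValueError when these per-group counts differ
for i, (live, saved) in enumerate(zip(optimizer.param_groups, checkpoint['opt']['param_groups'])):
    print(f"group {i}: optimizer has {len(live['params'])} params, checkpoint has {len(saved['params'])}")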

amandalucasp commented 2 years ago

Hm I see. It's this one: paws_imgnt_1percent_300ep.pth.tar

To be able to load the encoder, I changed https://github.com/facebookresearch/suncet/blob/731547d727b8c94d06c08a7848b4955de3a70cea/src/paws_train.py#L400 to

encoder_dict = {k.replace('module.', ''): v for k, v in checkpoint['encoder'].items()}  # strip the 'module.' prefix
checkpoint['encoder'] = {k: v for k, v in encoder_dict.items() if 'pred' not in k}  # drop the prediction-head keys
# -- loading encoder
encoder.load_state_dict(checkpoint['encoder'])

But I'm thinking I don't need to worry about the optimizer and can simply use a new one instead of the one from the checkpoint, since I'm not "resuming" the training. I just wanted to start learning from the ImageNet pre-trained weights.

MidoAssran commented 2 years ago

Ok so you're loading a model pre-trained with a prediction head, but you're resuming training without a prediction head?

amandalucasp commented 2 years ago

Sorry, that was actually a mistake. For the models I trained from scratch, I had better results without the prediction head... But for the use case in this issue, since I'm trying to use your checkpoint, I should keep the prediction head. I just tested the code after removing that second line and it works just fine. Thanks!
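
For reference, the loading code that works now is just this (only stripping the 'module.' prefix, keeping the prediction-head keys):

checkpoint['encoder'] = {k.replace('module.', ''): v for k, v in checkpoint['encoder'].items()}
encoder.load_state_dict(checkpoint['encoder'])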
