Closed mingqiJ closed 2 years ago
@mingqiJ you have to strip the aux batchnorm layers from the checkpoints if you want to use them later. clean_checkpoint.py can help with that:
https://github.com/rwightman/pytorch-image-models/blob/master/clean_checkpoint.py#L48
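As a rough illustration of what "stripping" means here, the sketch below drops every state dict entry whose key contains 'aux_bn'. This is an assumption based on the key names in the error message quoted in this thread, not a copy of the script itself. The helper name strip_aux_bn and the toy dictionary are hypothetical; a real checkpoint would hold tensors rather than floats.

```python
from collections import OrderedDict


def strip_aux_bn(state_dict):
    """Return a copy of state_dict without auxiliary BatchNorm entries.

    Assumption: the auxiliary entries are exactly the keys containing
    'aux_bn', as seen in the error message quoted in this thread.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if 'aux_bn' in key:
            continue  # skip split-BN auxiliary parameters and buffers
        cleaned[key] = value
    return cleaned


# Toy stand-in for checkpoint['state_dict']; real values would be tensors.
example = OrderedDict([
    ('bn1.weight', 1.0),
    ('bn1.bias', 0.0),
    ('bn1.aux_bn.0.weight', 1.0),
    ('bn1.aux_bn.1.running_mean', 0.0),
])
cleaned = strip_aux_bn(example)
print(list(cleaned))
```

With a real checkpoint, the cleaned dictionary is what would be passed to model.load_state_dict(...) in place of the original one.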
Hi @rwightman, thanks for your reply. Should I run it with a command like this?
python3 clean_checkpoint.py --checkpoint checkpoint-2.pth.tar --clean-aux-bn
@mingqiJ I assume you've tried this by now? please close if yes
Yes, I just finished this, thanks for your help!
I also have another question, which I put in the discussions. Could you take a look at it?
When I use this command to train a ResNet50 (with JSD loss and RandAugment, clean + 2x RA augs):
./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
I get some checkpoint .pth.tar files, but I cannot load the weights from them directly with the code below. Why?
import timm
import torch

checkpoint = torch.load('checkpoint-2.pth.tar')
model = timm.create_model('resnet50')
model.load_state_dict(checkpoint['state_dict'])  # raises the RuntimeError shown below
It raises an error like this: RuntimeError: Error(s) in loading state_dict for ResNet: Unexpected key(s) in state_dict: "bn1.aux_bn.0.weight", "bn1.aux_bn.0.bias", "bn1.aux_bn.0.running_mean", "bn1.aux_bn.0.running_var", "bn1.aux_bn.0.num_batches_tracked", "bn1.aux_bn.1.weight", "bn1.aux_bn.1.bias", "bn1.aux_bn.1.running_mean"..........
It seems the checkpoint has more layers than a plain ResNet50.
Does anyone know why?
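One way to see exactly which entries are extra is to compare the checkpoint's keys with the model's keys before loading. The sketch below does this on plain lists of key names; the function name diff_state_dict_keys and the toy key lists are hypothetical. With real objects the inputs would presumably be checkpoint['state_dict'].keys() and model.state_dict().keys().

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not have
    """
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    return missing, unexpected


# Toy key lists standing in for a real model and a split-BN checkpoint.
model_keys = ['conv1.weight', 'bn1.weight', 'bn1.bias']
checkpoint_keys = model_keys + ['bn1.aux_bn.0.weight', 'bn1.aux_bn.1.weight']

missing, unexpected = diff_state_dict_keys(checkpoint_keys, model_keys)
print('missing:', missing)
print('unexpected:', unexpected)
```

If every unexpected key contains aux_bn and nothing is missing, the extra entries most likely come from the auxiliary batchnorm layers created during training with --split-bn, rather than from a different architecture.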