zhanghang1989 / PyTorch-Encoding

A CV toolkit for my papers.
https://hangzhang.org/PyTorch-Encoding/
MIT License
2.04k stars 450 forks

Transfer learning unable to do #391

Open ravitejarj opened 3 years ago

ravitejarj commented 3 years ago

Hi @zhanghang1989, I am unable to do transfer learning with the model I downloaded, get_deeplab_resnest101_ade. When I changed the number of classes in ade20k.py to 8, the pretrained model no longer loads (I get an error).

So I have changed the code for Transfer learning

Code changes:

1) In deeplab.py, in the get_deeplab_resnest101_ade function, I changed

model = DeepLabV3(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)

to

no_of_classes = 150
model = DeepLabV3(no_of_classes, backbone=backbone, root=root, **kwargs)

so that I can load the pretrained model with 150 classes.

2) In train_dist.py:

Model loading

model_ft = get_deeplab_resnest101_ade(pretrained=True)

for param in model_ft.parameters():
    param.requires_grad = False

model_ft.head.block = Sequential(
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    ReLU(inplace=True),
    Dropout(p=0.1, inplace=False),
    Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1)),
)

for param in model_ft.head.parameters():
    param.requires_grad = True
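The freeze-then-replace pattern described here can be checked in isolation before starting a long training run. Below is a minimal sketch that assumes a hypothetical toy network (TinySegNet, not part of PyTorch-Encoding) in place of the real DeepLabV3 model. It freezes all existing parameters, swaps the final classifier for an 8-class one, and confirms both which parameters remain trainable and how many channels the output has.

```python
import torch
from torch import nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a segmentation net: a backbone plus a `head`."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.head = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, x):
        return self.head(self.backbone(x))


# Pretend this 150-class model carries pretrained weights.
model_ft = TinySegNet(num_classes=150)

# Freeze every existing parameter.
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final classifier; freshly created modules are trainable by default.
model_ft.head[1] = nn.Conv2d(256, 8, kernel_size=1)

trainable = [name for name, p in model_ft.named_parameters() if p.requires_grad]

model_ft.eval()
with torch.no_grad():
    out = model_ft(torch.randn(1, 3, 32, 32))

print(trainable)   # names of the parameters that will be updated
print(out.shape)   # channel dimension should equal the new class count
```

Running the same kind of check on the model object that the training script actually optimizes and saves would show whether the 8-class head is really in use.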

Training

python train_dist.py --dataset ade20k --model deeplab --aux --backbone resnest101 --ft --epochs 100

After training completes successfully, the output still has 150 classes instead of 8, even though I set 8 classes in the last layer. I need an 8-class output.

Can you help me with this?

zhanghang1989 commented 3 years ago

An easy solution is to set strict=False when loading the pretrained model: https://pytorch.org/docs/master/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict
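A minimal sketch of what this could look like, assuming a hypothetical toy network (make_net is illustrative, not part of PyTorch-Encoding) in place of the real checkpoint. One caveat: strict=False tolerates missing and unexpected keys, but load_state_dict still raises an error for a tensor whose shape differs from the model's. The sketch therefore drops the mismatched classifier weights before loading.

```python
import torch
from torch import nn


def make_net(num_classes):
    # Hypothetical toy network; layer "1" plays the role of the classifier.
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.Conv2d(16, num_classes, kernel_size=1),
    )


# Stands in for the downloaded 150-class checkpoint.
pretrained_state = make_net(150).state_dict()

# The new model has an 8-class classifier.
model = make_net(8)
model_state = model.state_dict()

# Keep only tensors whose name and shape match the new model.
compatible = {
    name: tensor
    for name, tensor in pretrained_state.items()
    if name in model_state and tensor.shape == model_state[name].shape
}

# strict=False lets the filtered-out classifier keys be missing.
result = model.load_state_dict(compatible, strict=False)

print(result.missing_keys)     # layers left at their fresh initialization
print(result.unexpected_keys)  # checkpoint keys the model did not use
```

The layers listed in missing_keys keep their random initialization and are the ones that need training on the new dataset.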

ravitejarj commented 3 years ago

Thank you