hunto / DIST_KD

Official implementation of paper "Knowledge Distillation from A Stronger Teacher", NeurIPS 2022
Apache License 2.0

How to train with custom datasets? #7

Open ManuBN786 opened 1 year ago

ManuBN786 commented 1 year ago

I trained a resnet34 teacher on my custom dataset with 9 classes. I arranged the dataset in the ImageNet format and modified dataset/builder.py like this:

# pre-configuration for the dataset
if args.dataset == 'imagenet':
    args.data_path = 'data/imagenet' if args.data_path == '' else args.data_path
    args.num_classes = 9
    args.input_shape = (3, 384, 384)

I used the command "python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet34 -c configs/strategies/resnet/resnet.yaml --teacher-pretrained --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 -b 32 --experiment teacher_model_train --epochs 100"

Even after 100 epochs, it shows the best.pt accuracy as 0.3!

After that I tried to train a student resnet18 with the command:

"python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet18 -c configs/strategies/distill/resnet_dist.yaml --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 --teacher-pretrained --teacher-ckpt experiments/teacher_model_train/best.pth.tar -b 16 --experiment student_model_train --epochs 100"

it shows this error:

12:29:01 INFO Model resnet18 created, params: 11.181 M, FLOPs: 5.330 G
12:29:02 INFO Loading pretrained checkpoint from experiments/teacher_model_train/best.pth.tar
Traceback (most recent call last):
  File "tools/train.py", line 363, in <module>
    main()
  File "tools/train.py", line 91, in main
    teacher_model = build_model(args, args.teacher_model, args.teacher_pretrained, args.teacher_ckpt)
  File "/home/manu/PycharmProjects/DIST_KD/classification/tools/models/builder.py", line 71, in build_model
    model.load_state_dict(ckpt, strict=False)
  File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
    size mismatch for fc.weight: copying a param with shape torch.Size([9, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
    size mismatch for fc.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([1000]).

Please tell me how to train with custom datasets.

ManuBN786 commented 1 year ago

I fixed the error by changing the teacher model name in 'configs/strategies/distill/resnet_dist.yaml' from 'tv_resnet34' to 'resnet34'. Now the student model trains well.
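For anyone hitting the same traceback: the size mismatch means the teacher was built with a 1000-class head while the checkpoint holds a 9-class head, and `load_state_dict(..., strict=False)` still raises on shape mismatches (it only tolerates missing or unexpected keys). Below is a minimal sketch of how one might check shapes before loading. `TinyNet` and `shape_mismatches` are hypothetical stand-ins for illustration, not part of this repo.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a ResNet: only the final fc layer matters here."""

    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        return self.fc(x)


def shape_mismatches(model, state_dict):
    """Return {key: (checkpoint_shape, model_shape)} for keys whose shapes differ."""
    model_state = model.state_dict()
    return {
        key: (tuple(value.shape), tuple(model_state[key].shape))
        for key, value in state_dict.items()
        if key in model_state and tuple(value.shape) != tuple(model_state[key].shape)
    }


ckpt = TinyNet(num_classes=9).state_dict()   # plays the role of the 9-class teacher checkpoint
wrong_model = TinyNet(num_classes=1000)      # teacher built with the default 1000 classes

bad = shape_mismatches(wrong_model, ckpt)
print(bad)  # lists fc.weight and fc.bias with their differing shapes

# strict=False does not skip shape mismatches, so this still raises.
try:
    wrong_model.load_state_dict(ckpt, strict=False)
    raised = False
except RuntimeError:
    raised = True

# Building the model with the same number of classes as the checkpoint loads cleanly.
right_model = TinyNet(num_classes=9)
right_model.load_state_dict(ckpt)
```

The point of the check is only diagnostic: the actual fix is to build the teacher with the same architecture and number of classes that the checkpoint was trained with.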

But I don't know how to improve the teacher model's accuracy.

hunto commented 1 year ago

Dear @ManuBN786 ,

Sorry for the late reply. Have you tried your dataset and training settings on your training framework or other frameworks?

ManuBN786 commented 1 year ago

Yes, on a resnet50 from PyTorch it gives a validation accuracy of 0.93.

I don't know why the validation accuracy is so poor when using DIST_KD.

hunto commented 1 year ago

One bug I can find is that your training uses input images with 384x384 resolution, but the resolution in our framework is hard-coded to 224. (See build_train_transforms and build_val_transforms in https://github.com/hunto/image_classification_sota/blob/main/lib/dataset/transform.py)

You should manually change all the 224 to 384 at L21, L32, and L61; and change 256 to 440 at L60.
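As a side note on where 440 may come from: the common evaluation recipe resizes to 256 and center-crops 224, and applying the same 256/224 ratio to a 384 crop gives about 438.9. The sketch below assumes that ratio plus rounding up to a multiple of 8; the rounding rule and the helper name `resize_for_crop` are assumptions for illustration, not something defined in the repo.

```python
def resize_for_crop(crop_size, base_resize=256, base_crop=224, multiple=8):
    """Scale the resize edge with the crop size, rounding up to `multiple`.

    Uses integer ceiling division to avoid floating-point rounding issues.
    """
    step = base_crop * multiple
    return -(-crop_size * base_resize // step) * multiple


print(resize_for_crop(224))  # 256, the default resize for a 224 crop
print(resize_for_crop(384))  # 440, matching the value suggested above
```

If a different crop size is used later, the same helper gives a resize value that keeps roughly the same border around the center crop.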

ManuBN786 commented 1 year ago

Ok. Thanks for letting me know.

ManuBN786 commented 1 year ago

[Image: low_acc]

I made all of the changes mentioned above for image size 384, but I still get very low accuracy for the teacher.

hunto commented 1 year ago

It's difficult for me to identify the differences between this repo and the PyTorch example code. If you want to use DIST KD in your project, I think the easiest way is to add the KD code to your existing, working training code: you just need to initialize a pretrained teacher, compute its outputs for each input batch, and compute and backpropagate the KD loss.
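A rough sketch of those three steps in plain PyTorch is given here. It is a simplified, Pearson-correlation-based loss in the spirit of DIST, not the repo's exact implementation; the function names, the `beta`/`gamma` weights of 1.0, and the tiny `nn.Linear` models standing in for the real teacher and student are all assumptions for illustration.

```python
import torch
import torch.nn as nn


def pearson(a, b, eps=1e-8):
    """Row-wise Pearson correlation between two (N, C) tensors."""
    a = a - a.mean(dim=1, keepdim=True)
    b = b - b.mean(dim=1, keepdim=True)
    return (a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def dist_style_loss(student_logits, teacher_logits, beta=1.0, gamma=1.0):
    """Simplified DIST-style loss: match class relations per sample and per class."""
    p_s = student_logits.softmax(dim=1)
    p_t = teacher_logits.softmax(dim=1)
    inter = 1.0 - pearson(p_s, p_t).mean()          # per sample, across classes
    intra = 1.0 - pearson(p_s.t(), p_t.t()).mean()  # per class, across the batch
    return beta * inter + gamma * intra


torch.manual_seed(0)
num_classes = 9

# Hypothetical stand-ins; in practice these would be the trained teacher and the student.
teacher = nn.Linear(16, num_classes)
student = nn.Linear(16, num_classes)

# Step 1: initialize the (pretrained) teacher and freeze it.
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Dummy batch in place of a real data loader.
x = torch.randn(8, 16)
y = torch.randint(0, num_classes, (8,))

# Step 2: compute teacher outputs for the batch without tracking gradients.
with torch.no_grad():
    t_logits = teacher(x)
s_logits = student(x)

# Step 3: compute the KD loss, add it to the task loss, and backpropagate.
kd_loss = dist_style_loss(s_logits, t_logits)
loss = criterion(s_logits, y) + kd_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the student's parameters are passed to the optimizer, so the teacher stays fixed throughout training.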