IDEA-Research / DINO

[ICLR 2023] Official implementation of the paper "DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection"
Apache License 2.0

Training from a pretrained checkpoint not working for Swin 5scale #246

Open jmunshi1990 opened 5 months ago

jmunshi1990 commented 5 months ago

Hi team,

I am trying to train/fine-tune the swin_5scale model on a custom dataset with 7 classes. Below is the script I use to start training. I am loading one of the uploaded 5scale_swin .pth checkpoint files, but I get the error below as soon as training starts.

PS: The same training script with the same dataset works when training from scratch.

```shell
coco_path=$1
backbone_dir=$2
export CUDA_VISIBLE_DEVICES=$3 && python main.py \
    --output_dir logs/DINO/dino_cdp_swin_5scale_finetune_test -c config/DINO/DINO_5scale_swin.py --coco_path $coco_path \
    --options dn_scalar=100 embed_init_tgt=TRUE \
    dn_label_coef=1.0 dn_bbox_coef=1.0 use_ema=False \
    dn_box_noise_scale=1.0 backbone_dir=$backbone_dir \
    num_classes=7 dn_labelbook_size=8 epochs=36 \
    batch_size=8 --pretrain_model_path pretrained_weights/checkpoint0027_5scale_swin.pth \
    --finetune_ignore label_enc.weight class_embed
```
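The `--finetune_ignore` list only protects the keys it names. Below is a minimal sketch, assuming (from the flag name and the `Ignore keys` log line, not from the DINO source) that a checkpoint key is skipped when any listed string occurs as a substring of it. `filter_state_dict` is a hypothetical helper; the `input_proj.*` keys and shapes come from the error log, while `class_embed.0.weight` is an assumed example key name. Under that assumption, the classifier and label-embedding entries are dropped, but every `input_proj.*` weight is still handed to `load_state_dict`.

```python
from collections import OrderedDict
from typing import Any, Iterable, List, Mapping, Tuple


def filter_state_dict(
    state_dict: Mapping[str, Any], ignore_keywords: Iterable[str]
) -> Tuple["OrderedDict[str, Any]", List[str]]:
    """Split a checkpoint into kept entries and ignored key names.

    Hypothetical helper: a key is ignored if any keyword is a substring of it.
    """
    keywords = list(ignore_keywords)
    kept: "OrderedDict[str, Any]" = OrderedDict()
    ignored: List[str] = []
    for key, value in state_dict.items():
        if any(kw in key for kw in keywords):
            ignored.append(key)
        else:
            kept[key] = value
    return kept, ignored


if __name__ == "__main__":
    # Values stand in for tensors; None marks shapes not shown in the log.
    checkpoint = OrderedDict([
        ("label_enc.weight", None),
        ("class_embed.0.weight", None),  # assumed key name
        ("input_proj.0.0.weight", (256, 192, 1, 1)),
        ("input_proj.4.0.weight", (256, 1536, 3, 3)),
    ])
    kept, ignored = filter_state_dict(checkpoint, ["label_enc.weight", "class_embed"])
    print("ignored:", ignored)
    print("kept:", list(kept))
```

Note that `strict=False` in PyTorch's `load_state_dict` only tolerates missing or unexpected keys; a key present in both the checkpoint and the model with different shapes is still reported as an error, which is why the `input_proj.*` entries abort the load.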

Error:

```
/anaconda/envs/wli_dino/lib/python3.9/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
[01/23 19:09:31.216]: Ignore keys: []
Traceback (most recent call last):
  File "****/wli_ai_defect/DINO/main.py", line 388, in <module>
    main(args)
  File "/wli_ai_defect/DINO/main.py", line 242, in main
    _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
  File "*****/anaconda/envs/wli_dino/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DINO:
	size mismatch for input_proj.0.0.weight: copying a param with shape torch.Size([256, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 384, 1, 1]).
	size mismatch for input_proj.1.0.weight: copying a param with shape torch.Size([256, 384, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 768, 1, 1]).
	size mismatch for input_proj.2.0.weight: copying a param with shape torch.Size([256, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1536, 1, 1]).
	size mismatch for input_proj.3.0.weight: copying a param with shape torch.Size([256, 1536, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1536, 3, 3]).
	size mismatch for input_proj.4.0.weight: copying a param with shape torch.Size([256, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
```
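The mismatched shapes follow a regular pattern. Below is a sketch, assuming `input_proj` is built the way Deformable DETR builds it: one 1x1 conv projecting to `hidden_dim=256` per backbone output level, followed by stride-2 3x3 convs for any extra levels, the first reading the last backbone map and later ones reading the previous extra level. It also uses Swin-L's stage widths of 192, 384, 768 and 1536 channels. `input_proj_shapes` is a hypothetical helper. Under these assumptions, the checkpoint's shapes are reproduced by feeding all four Swin stages plus one extra level, and the current model's shapes by feeding only the last three stages plus two extra levels. This suggests the checkpoint and the config used here disagree on which backbone stages feed the transformer.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int, int]


def input_proj_shapes(
    backbone_channels: Sequence[int], num_feature_levels: int, hidden_dim: int = 256
) -> List[Shape]:
    """Conv weight shape for each input_proj level (assumed Deformable-DETR layout).

    Hypothetical helper; assumes num_feature_levels >= len(backbone_channels).
    """
    shapes: List[Shape] = []
    # One 1x1 projection per backbone feature map.
    for in_ch in backbone_channels:
        shapes.append((hidden_dim, in_ch, 1, 1))
    # Extra levels: stride-2 3x3 convs; the first reads the last backbone map,
    # each later one reads the previous extra level (hidden_dim channels).
    in_ch = backbone_channels[-1]
    for _ in range(num_feature_levels - len(backbone_channels)):
        shapes.append((hidden_dim, in_ch, 3, 3))
        in_ch = hidden_dim
    return shapes


SWIN_L_STAGE_CHANNELS = [192, 384, 768, 1536]  # Swin-L output widths, stages 0..3

if __name__ == "__main__":
    checkpoint_like = input_proj_shapes(SWIN_L_STAGE_CHANNELS, 5)         # stages 0-3
    current_model_like = input_proj_shapes(SWIN_L_STAGE_CHANNELS[1:], 5)  # stages 1-3
    for level, (ckpt, cur) in enumerate(zip(checkpoint_like, current_model_like)):
        print(f"input_proj.{level}.0.weight: checkpoint {ckpt} vs model {cur}")
```

Both lists match the five `size mismatch` lines in the log level by level, so the error is about backbone feature levels rather than the 7-class head, which `--finetune_ignore` already covers.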

Raneem-MT commented 2 months ago

Hi @jmunshi1990, have you managed to solve this?

tahirashehzadi commented 2 months ago

Have you managed to solve this?