microsoft / FocalNet

[NeurIPS 2022] Official code for "Focal Modulation Networks"
MIT License

How to use FocalNet-DINO pretrained with Object365 #25

Closed: Kun-Ming closed this issue 1 year ago

Kun-Ming commented 1 year ago

Hi,

Thank you for sharing the work!

I am using your model pretrained on the Objects365 dataset, as indicated here: https://github.com/FocalNet/FocalNet-DINO. However, when I run the code in https://github.com/FocalNet/FocalNet-DINO/blob/main/inference_and_visualization.ipynb to predict on some images, it fails in /FocalNet-DINO/models/dino/backbone.py with NotImplementedError: Unknown backbone focalnet_large_fl4_pretrained_on_o365. I added focalnet_large_fl4_pretrained_on_o365 to the dict at https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/backbone.py#L229 and at https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/backbone.py#L205, but the same NotImplementedError: Unknown backbone focalnet_large_fl4_pretrained_on_o365 is then raised in the function at https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/focal.py#L515. I could not fix it because of the unknown parameter at https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/focal.py#L531.

Could you help me to use the model? Thank you for your time.

jwyang commented 1 year ago

Hi, which config did you use to load the ckpt?

Kun-Ming commented 1 year ago

Hi, I am using https://github.com/FocalNet/FocalNet-DINO/blob/main/config/DINO/DINO_5scale_focalnet_large_fl4.py. The backbone in this file is focalnet_L_384_22k_fl4, and I changed it to focalnet_large_fl4_pretrained_on_o365 because focalnet_L_384_22k_fl4.pth could not be found. I then tried to fix the bug that appeared after this change: I copied https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/focal.py#L543 into a new dict called focalnet_large_fl4_pretrained_on_o365 in this function. After loading focalnet_large_fl4_pretrained_on_o365.pth, the FocalNet backbone no longer reports size mismatches, but the encoder and decoder do:

    size mismatch for transformer.decoder.class_embed.0.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.0.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.decoder.class_embed.1.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.1.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.decoder.class_embed.2.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.2.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.decoder.class_embed.3.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.3.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.decoder.class_embed.4.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.4.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.decoder.class_embed.5.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.decoder.class_embed.5.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for transformer.enc_out_class_embed.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for transformer.enc_out_class_embed.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for label_enc.weight: copying a param with shape torch.Size([401, 256]) from checkpoint, the shape in current model is torch.Size([92, 256]).
    size mismatch for class_embed.0.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.0.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for class_embed.1.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.1.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for class_embed.2.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.2.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for class_embed.3.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.3.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for class_embed.4.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.4.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).
    size mismatch for class_embed.5.weight: copying a param with shape torch.Size([366, 256]) from checkpoint, the shape in current model is torch.Size([91, 256]).
    size mismatch for class_embed.5.bias: copying a param with shape torch.Size([366]) from checkpoint, the shape in current model is torch.Size([91]).

jwyang commented 1 year ago

Great. Actually, you do not need to change the config; just specify "--pretrain_model_path".

To work around the size mismatches above, you can simply append "--finetune_ignore class_embed label_enc" to your command line. This makes checkpoint loading skip the class_embed and label_enc parameters shown in your log.
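
As a rough illustration of what keyword-based ignoring amounts to, the sketch below filters a toy checkpoint by substring match. The key names are taken from the error log above; `keep_key` and `filter_state_dict` are hypothetical helper names for this example, not the repository's API, and the string values stand in for tensors.

```python
from collections import OrderedDict


def keep_key(name, ignore_keywords):
    """Return True when `name` contains none of the ignored keywords."""
    return not any(keyword in name for keyword in ignore_keywords)


def filter_state_dict(state_dict, ignore_keywords):
    """Drop every entry whose key contains an ignored keyword."""
    return OrderedDict(
        (k, v) for k, v in state_dict.items() if keep_key(k, ignore_keywords)
    )


# Toy checkpoint: the values are placeholders standing in for tensors.
checkpoint = OrderedDict([
    ("backbone.0.weight", "w0"),
    ("transformer.decoder.class_embed.0.weight", "w1"),
    ("transformer.enc_out_class_embed.weight", "w2"),
    ("label_enc.weight", "w3"),
    ("class_embed.0.bias", "w4"),
])

filtered = filter_state_dict(checkpoint, ["class_embed", "label_enc"])
print(list(filtered))
```

Because the match is by substring, "class_embed" also removes keys such as transformer.enc_out_class_embed.weight. The parameters removed this way are not restored from the checkpoint, so the model keeps whatever values it was initialized with for them.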

Kun-Ming commented 1 year ago

Thanks for your advice. Is class_embed not necessary for model inference? I am using the raw O365 pretrained model without fine-tuning on any other dataset. After ignoring class_embed and label_enc, I used it for inference on a single image, but the output is not reasonable: the boxes drawn on the image do not bound the objects. Here is my code:

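# Imports. The repository-local paths below are assumed from the repo layout;
# adjust them if a helper lives elsewhere (e.g. clean_state_dict).
from collections import OrderedDict

import torch
from PIL import Image

import datasets.transforms as T
from main import build_model_main
from util import box_ops
from util.slconfig import SLConfig
from util.utils import clean_state_dict
from util.visualizer import COCOVisualizer


def check_keep(keyname, ignorekeywordlist):
    # Keep a checkpoint key only if it contains none of the ignored keywords.
    return not any(keyword in keyname for keyword in ignorekeywordlist)


# Load model
model_config_path = "config/DINO/DINO_5scale_focalnet_large_fl4.py"
model_checkpoint_path = "focalnet_large_fl4_pretrained_on_o365.pth"
_ignorekeywordlist = ['class_embed', 'label_enc']

checkpoint = torch.load(model_checkpoint_path, map_location='cpu')['model']
_tmp_st = OrderedDict({k: v for k, v in clean_state_dict(checkpoint).items() if check_keep(k, _ignorekeywordlist)})

args = SLConfig.fromfile(model_config_path)
args.device = 'cuda'
model, criterion, postprocessors = build_model_main(args)
# strict=False is required here: the ignored keys are absent from _tmp_st,
# so those parameters keep their initial (untrained) values.
model.load_state_dict(_tmp_st, strict=False)
_ = model.eval()

# Inference
image = Image.open("animal.jpeg").convert("RGB")
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(image, None)
with torch.no_grad():  # no gradients are needed for inference
    output = model.cuda()(image[None].cuda())
# A target size of (1, 1) keeps the predicted boxes in normalized coordinates.
output = postprocessors['bbox'](output, torch.Tensor([[1.0, 1.0]]).cuda())[0]

# Visualization
threshold = 0.05
vslzr = COCOVisualizer()

scores = output['scores']
labels = output['labels']
boxes = box_ops.box_xyxy_to_cxcywh(output['boxes'])
select_mask = scores > threshold  # keep only predictions above the score threshold

pred_dict = {
    'boxes': boxes[select_mask],
    'size': torch.Tensor([image.shape[1], image.shape[2]]),
}
vslzr.visualize(image, pred_dict, savedir=None, dpi=100)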

The result looks like this: [screenshot, 2023-02-14, 4:42:24 PM]. The anchor points do not appear to be in the correct positions. I checked the input size against https://github.com/FocalNet/FocalNet-DINO/blob/main/datasets/coco.py#L487, and it looks fine.

Could you please give me some suggestions to solve this?

jwyang commented 1 year ago

For direct inference, you need to make sure all parameters are loaded correctly; not a single parameter should be left out. For further finetuning on COCO, however, you can drop the head part.
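
One way to check this before running inference is to compare the parameter names the model expects with the names that were actually loaded. The sketch below does that with plain sets; `loading_report` is a hypothetical helper name, and the key lists are shortened examples based on the log above rather than the model's full parameter list.

```python
def loading_report(model_keys, loaded_keys):
    """Compare the model's parameter names with the names actually loaded."""
    model_keys, loaded_keys = set(model_keys), set(loaded_keys)
    return {
        # Expected by the model but never loaded: these stay at their initial values.
        "missing": sorted(model_keys - loaded_keys),
        # Present in the checkpoint but unknown to the model.
        "unexpected": sorted(loaded_keys - model_keys),
    }


# Shortened example: names the model expects vs. names left after filtering.
model_keys = [
    "backbone.0.weight",
    "class_embed.0.weight",
    "class_embed.0.bias",
    "label_enc.weight",
]
loaded_keys = ["backbone.0.weight"]

report = loading_report(model_keys, loaded_keys)
ready_for_inference = not report["missing"] and not report["unexpected"]
print(report["missing"], ready_for_inference)
```

In PyTorch itself, `model.load_state_dict(state_dict, strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields that carry the same information; both should be empty for direct inference.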

Kun-Ming commented 1 year ago

Thanks. So how can I correctly load class_embed and label_enc, which currently cannot be loaded? Is something wrong with how I load the model?

jwyang commented 1 year ago

It is probably because your config is wrong. Please make sure num_classes=366 instead of 91 in your model config if you want to load the O365 model for evaluation.

Kun-Ming commented 1 year ago

Thanks! That solves my problem. Besides setting num_classes in the config file, I also needed to change dn_labelbook_size to 400. What is this parameter used for?
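
The two config values can be read off the checkpoint shapes reported in the size-mismatch log earlier in this thread. The sketch below assumes, based only on that log and the fix above, that class_embed has num_classes rows and that label_enc has dn_labelbook_size + 1 rows (401 rows loaded once dn_labelbook_size was set to 400). `infer_head_sizes` is a hypothetical helper name, not part of the repository.

```python
def infer_head_sizes(shapes):
    """Derive config values from checkpoint parameter shapes.

    Assumption (from the log in this thread, not from the repository docs):
    class_embed rows == num_classes, label_enc rows == dn_labelbook_size + 1.
    """
    num_classes = shapes["class_embed.0.weight"][0]
    dn_labelbook_size = shapes["label_enc.weight"][0] - 1
    return num_classes, dn_labelbook_size


# Shapes copied from the size-mismatch log for the O365 checkpoint.
checkpoint_shapes = {
    "class_embed.0.weight": (366, 256),
    "label_enc.weight": (401, 256),
}

num_classes, dn_labelbook_size = infer_head_sizes(checkpoint_shapes)
print(num_classes, dn_labelbook_size)
```

With a real checkpoint, a shape map like this could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`, and the resulting values then set in the model config before building the model.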