CVMI-Lab / SlotCon

(NeurIPS 2022) Self-Supervised Visual Representation Learning with Semantic Grouping
https://wen-xin.info/slotcon/
Apache License 2.0

ViT training ? #10

Open alexcbb opened 11 months ago

alexcbb commented 11 months ago

Hello,

Thank you for your very interesting work! I'm currently trying to replicate your results with the provided codebase, and I was wondering whether you also tested a Vision Transformer architecture as the encoder. You compared against DINO in the paper, but I wanted to know if you were able to get properties close to what they obtained (a kind of saliency map, with the attention map concentrated around the object of interest).

Thank you again for your response !

xwen99 commented 11 months ago

Hi @alexcbb, thanks for your attention to our work!

We actually didn't thoroughly experiment with ViTs due to computation constraints. Regarding the object-centric attention maps of DINO, we believe that is a merit of Transformers, and for CNNs we need to find another path. Our method explores doing it via explicit clustering on top of CNN features, which indeed worked. Besides that, we also tried to find similar visualizations within CNNs themselves, and we found that PCA on dense feature maps produced plausible results. Due to the hierarchical structure of CNNs, the resulting visualization's resolution is relatively low. We tried tricks like modifying the stride, which did not help much. Hope that is helpful for you!
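The PCA visualization mentioned above can be sketched roughly as follows. This is a minimal NumPy illustration, not the authors' script: the function name `pca_rgb` and the `(C, H, W)` input layout are assumptions. It projects every spatial location onto the top three principal components and rescales them to RGB.

```python
import numpy as np

def pca_rgb(feat):
    """Map a (C, H, W) feature map to an (H, W, 3) image via its top-3 principal components."""
    c, h, w = feat.shape
    x = feat.reshape(c, h * w).T              # (H*W, C): one row per spatial location
    x = x - x.mean(axis=0, keepdims=True)     # center the features before PCA
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    proj = x @ vt[:3].T                       # coordinates along the 3 leading components
    lo, hi = proj.min(axis=0), proj.max(axis=0)
    proj = (proj - lo) / (hi - lo + 1e-8)     # rescale each component to [0, 1]
    return proj.reshape(h, w, 3)
```

The output has the same spatial size as the feature map. Assuming a ResNet-50 backbone, the last stage has stride 32, so a 224x224 input gives only a 7x7 map, which is consistent with the low resolution described above.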

alexcbb commented 11 months ago

Thank you for your quick answer !

I would be very interested to explore whether such training would be beneficial for Vision Transformers (even a small version like ViT-S/16). I'm first trying to check whether I can reproduce your results with ResNet, and then I want to apply it to ViT. I think this could help extract object knowledge to some extent and bring some prior for the training, even more so on scene-centric datasets.

Can I ask for your help in this process ? (mainly on the replication of the results)

Thanks again !

xwen99 commented 11 months ago

Feel free to leave a message if there is trouble working on that.

alexcbb commented 11 months ago

For the pre-training on COCO, it is indicated that it was performed on 8 NVIDIA 2080 Ti GPUs for 800 epochs. Do you have an approximate training time for such a run, and possibly some information on memory consumption?

xwen99 commented 11 months ago

It should take up almost all of the memory of 8x 2080 Ti, roughly 80GB in total. I do not remember the precise training time, maybe roughly 2~3 days?

alexcbb commented 11 months ago

I made some small changes to launch the training (I created a PyTorch Lightning module to ease deployment on clusters) and started an 800-epoch training on COCO. Here is an overview of the loss so far (it is now at around 230 epochs after ~1 day of training). Does this look like a reasonable convergence curve? [Image: Screenshot from 2023-12-19 17-30-11]

alexcbb commented 11 months ago

It seems that I'm not able to replicate your Figure 3 after pre-training, and I don't understand why: the prototypes seem a bit weird and there is no mask on my final image.

xwen99 commented 11 months ago

Fig 3 is simply produced using viz_slots.py with the default configs. The model is the default model on coco, with 800 epochs of training. Please check if there are any errors in your reimplementation.

alexcbb commented 10 months ago

Hello, the problem was in my visualization file; it seems that I'm now able to obtain well-aligned concepts! I've seen in the paper that you say one would need to scale the loss according to the batch size (if it were increased). Can you tell me more about this? (I've trained my model using your default parameters and a batch size of 1536 without any major problem in the results.)

xwen99 commented 10 months ago

Hi, we scale the learning rate linearly with the batch size, as done by many previous works. This part is already implemented in the code, and basically no more modification is needed for you: https://github.com/CVMI-Lab/SlotCon/blob/dcbe6dc80a6745a711b741ed58120a42c3d69996/main_pretrain.py#L80
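As a rough illustration of the linear scaling rule described in this reply, here is a minimal sketch. The helper name `scaled_lr` and the reference batch size of 256 are assumptions (256 is a common convention); check the linked line in `main_pretrain.py` for the exact form used in the repository.

```python
def scaled_lr(base_lr, total_batch_size, reference_batch_size=256):
    """Linear scaling rule: the learning rate grows in proportion to the batch size."""
    # reference_batch_size=256 is an assumed convention, not confirmed by the thread.
    return base_lr * total_batch_size / reference_batch_size
```

Under this assumption, doubling the total batch size doubles the effective learning rate, while `base_lr` in the config stays unchanged.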

alexcbb commented 10 months ago

Hello, I have a question concerning the slot loss. In equation (5) you specify a masking over the slots that do not occupy dominating groups, which you then use when computing an InfoNCE loss. In the ctr_loss_filtered function of SlotCon, I was wondering why you use mask_intersection rather than mask_q to select the slots of q. Is it to avoid slots that do not have a positive pair in k, or is there another explanation? Also, did you run any ablations on whether this masking helps training? Thank you in advance.

xwen99 commented 10 months ago

Your guess is correct: this is to make sure they form a positive pair, i.e. that both the query and key slots exist across views. From memory, we didn't ablate much on that.

alexcbb commented 10 months ago

Ok, thank you again for your answer. Did you ever encounter the case where there is no positive pair in the views? While I was trying to train the model with a ViT backbone, I obtained a NaN loss, and the issue comes from the slot loss: at a certain point it is not able to get any intersection mask. If you encountered this during your experiments, any pointer would help me a lot!

xwen99 commented 10 months ago

Well, actually I can't recall the details well... You may consider dropping that pair in this case.
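One way to read the suggestion above is to let a pair contribute zero loss when no slot is active in both views, because averaging a cross-entropy over zero queries is what produces the NaN. The following is a simplified single-process NumPy sketch, not the repository's `ctr_loss_filtered`: the name `filtered_slot_loss` is hypothetical, and the predictor head and cross-GPU gathering are omitted.

```python
import numpy as np

def filtered_slot_loss(q, k, mask_q, mask_k, tau=0.2):
    """InfoNCE over slots; returns 0.0 when no slot is active in both views."""
    mask_q = np.asarray(mask_q, dtype=bool)
    mask_k = np.asarray(mask_k, dtype=bool)
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    k = k / np.linalg.norm(k, axis=1, keepdims=True)
    idx_q = np.flatnonzero(mask_q & mask_k)   # queries whose positive key exists
    idx_k = np.flatnonzero(mask_k)            # all active keys (positives and negatives)
    if idx_q.size == 0:
        # Averaging over zero queries would be 0/0 (NaN); drop this pair instead.
        return 0.0
    logits = q[idx_q] @ k[idx_k].T / tau
    labels = np.cumsum(mask_k)[idx_q] - 1     # column of each query's own key
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(idx_q.size), labels].mean())
```

In a distributed run, the early return would need to come after any collective operation such as an all-gather; otherwise the ranks that skip the pair and the ranks that do not would fall out of step.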

KJ-rc commented 9 months ago

Hi, @alexcbb @xwen99

I have just conducted a small-scale experiment with ViT-S on COCO for 100 epochs. The rest of the settings can be found below.

Prototype visualization makes sense but is weird. I am checking the code now. I would appreciate any hints or suggestions.

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 \
    main_pretrain.py \
    --dataset COCO \
    --data-dir ${data_dir} \
    --output-dir ${output_dir} \
    \
    --arch vit_small \
    --dim-hidden 4096 \
    --dim-out 256 \
    --num-prototypes 256 \
    --teacher-momentum 0.99 \
    --teacher-temp 0.07 \
    --group-loss-weight 0.5 \
    \
    --batch-size 256 \
    --optimizer adamw \
    --base-lr 5e-4 \
    --weight-decay 0.04 \
    --warmup-epoch 5 \
    --epochs 100 \
    --fp16 \
    \
    --print-freq 10 \
    --save-freq 50 \
    --auto-resume \
    --num-workers 12

slotcon_vits_coco_100eps

alexcbb commented 9 months ago

Hello, on my side I was not able to make the training converge properly. The slot loss returns a NaN from time to time at the masking step, and I don't know why this is happening. May I ask what changes you made to replace the backbone? Concerning the hyperparameters, I have the same as yours (I took the same hyperparameters as the DINO training). Your prototypes look coherent to me; what seems weird to you? I would gladly discuss your re-implementation with you if you agree!

KJ-rc commented 9 months ago

Sure. You can send me an email. ([Update] - my email: kjliu@vision.is.tohoku.ac.jp) For a quick answer, I tried to make minimal changes. Specifically,

Regarding prototype visualization, I found that there are some empty prototypes, and the 4th column from the right shows "cat", "cow" and "bear", while the ResNet50-based one can output a pure cat prototype. I would say the semantic consistency is lower in the ViT-S-based one. Again, I am not sure whether this behavior is correct or not.

alexcbb commented 9 months ago

@KJ-rc Ok, thank you. Can you provide me your email? I'm not able to find it. Concerning the implementation from DINO, you trained it from scratch, right? I will try as you said; maybe there are some issues in my code that I didn't see. I will let you know if I get the same artifact as yours. But I'm pretty sure I've done the same as you did (very few changes).

But yes, you are right, I hadn't seen the empty prototypes! This is quite weird indeed. Maybe some changes in the hyperparameters could make it a bit better (like the student/teacher temperature?). What about also letting it train for longer (like 300~400 epochs)? I will let you know as soon as I'm able to train the ViT from scratch on my side.

alexcbb commented 9 months ago

@KJ-rc Just to be sure: as DINO does not use BatchNorm in its projection head, did you also remove the BatchNorm (and subsequently the SyncBatchNorm calls in SlotCon)?

KJ-rc commented 9 months ago

Hi, I did only the modifications listed above. I consider projectors to depend more on the pre-training method than on the backbone architecture, so I kept the batch norm layer.

alexcbb commented 9 months ago

It seems that by reducing the batch size to 256 and adding gradient clipping, the training is now working. I'll see how it evolves and let you know about my final results!
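For reference, gradient clipping by global norm can be sketched as below. This is a framework-agnostic NumPy illustration of the rule that PyTorch's `torch.nn.utils.clip_grad_norm_` applies; the function name `clip_by_global_norm` is hypothetical, and the thread does not say which threshold was used, so `max_norm` is left to the caller.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    # Shrink only when the norm exceeds the threshold; never scale gradients up.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm
```

Logging the returned `total_norm` is a cheap way to see whether the loss spikes that precede a NaN coincide with exploding gradients.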

alexcbb commented 9 months ago

Hello, it seems that I've got the same issue as @KJ-rc when training SlotCon with ViT, but I think this issue arises because "dead slots" appear during training, as described in Appendix D. I printed out 100 slots to check their semantics, and of those 100, around 20 were "dead slots" without any meaning. It seems to be closely related to the discussion in Appendix D; what do you think? [Image: res3]
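A quick way to quantify this without inspecting every prototype by eye is to count how many prototypes never win the per-pixel assignment over a sample of images. The sketch below assumes the hard assignments (the argmax over prototypes for each pixel) are already available as an integer array. The name `dead_prototypes` is hypothetical, and "dead" here simply means "never selected", which may differ from the definition used in the paper's appendix.

```python
import numpy as np

def dead_prototypes(assignments, num_prototypes):
    """Return indices of prototypes that receive no pixel, and their fraction."""
    counts = np.bincount(np.asarray(assignments).ravel(), minlength=num_prototypes)
    dead = np.flatnonzero(counts == 0)        # prototypes that were never chosen
    return dead, dead.size / num_prototypes
```

Tracking this fraction over the course of training would show whether prototypes die early or are lost gradually.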

xwen99 commented 9 months ago

See if this paper helps you understand the dead slots: https://openreview.net/forum?id=Z2dVrgLpsF

alexcbb commented 9 months ago

Thank you for sharing. It would be interesting to evaluate whether the problem disappears with such regularization!