nomewang / M3DM

MIT License
153 stars 21 forks

About labels in contrastive loss #7

Closed. SionAn closed this issue 1 year ago

SionAn commented 1 year ago

Thank you for your interesting work and for sharing the source code. I have a question about the labels used in the contrastive loss when training the UFF model. Is the rank offset in "labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()" necessary? As far as I know, each GPU computes features for only its part of the batch, e.g., 64 of 256 samples when using 4 GPUs. Then N is 64 x 28 x 28 if there are no zero patches. Since the GPUs combine gradients rather than logits, the labels should take integer values from 0 to N-1 on every GPU. In my case, the "N * torch.distributed.get_rank()" term causes an error.

nomewang commented 1 year ago

Thanks for your comments! The contrastive code does have a bug in multi-GPU training, and we will fix it soon. Our model was trained on a single GPU, so this bug doesn't affect the reported results.
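To make the label question above concrete, here is a minimal pure-Python sketch of the indexing logic, not the repository's code. The helper names `contrastive_labels` and `num_keys`, and the `keys_gathered` flag, are hypothetical and introduced only for illustration. It assumes the rank offset is only meaningful when each process scores its N queries against keys gathered from all processes (logits of shape N x N*world_size); when each process only sees its own N keys, valid labels are 0 to N-1.

```python
def num_keys(n_local, world_size, keys_gathered):
    """Number of key columns in the logits matrix on one process."""
    return n_local * world_size if keys_gathered else n_local


def contrastive_labels(n_local, rank, keys_gathered):
    """Index of the positive key for each of the n_local queries.

    n_local: patches handled by this process (N in the thread).
    rank: process index, i.e. what torch.distributed.get_rank() returns.
    keys_gathered: hypothetical flag; True if keys from all processes
        were concatenated before computing logits, False otherwise.
    """
    # The offset N * rank points at this process's slice of gathered keys.
    offset = n_local * rank if keys_gathered else 0
    return [offset + i for i in range(n_local)]


# Toy sizes: 4 patches per process, 4 processes, looking at rank 2.
local_only = contrastive_labels(4, rank=2, keys_gathered=False)
with_offset = contrastive_labels(4, rank=2, keys_gathered=True)

# Applying the offset without gathering keys yields out-of-range labels,
# which is consistent with the error reported in this issue.
out_of_range = max(with_offset) >= num_keys(4, 4, keys_gathered=False)

# Patch count per GPU from the thread: 64 samples x 28 x 28 patches.
n_per_gpu = 64 * 28 * 28
```

With these toy sizes, `local_only` is `[0, 1, 2, 3]` and `with_offset` is `[8, 9, 10, 11]`. On a single GPU the rank is 0, so both variants coincide, which matches the statement that single-GPU results are unaffected.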

SionAn commented 1 year ago

Thank you for your response. I have another question. In the paper, you mention that the model trained for 750 iterations is used for evaluation. Is that about 100 epochs?

nomewang commented 1 year ago

In the paper, the batch size is counted in useful patches. In the repo, you can just set --batch_size=16 and train the UFF module for only 1-3 epochs.

SionAn commented 1 year ago

Thank you for your kind response. I am trying to reproduce the results, but performance drops significantly when using the UFF model. (For bagel, I-AUROC is 0.769 with UFF but 0.994 without UFF.) I think there is an issue with the UFF training process, even though I just ran the code following your comments. Could I get the trained UFF model?

nomewang commented 1 year ago

We will release the trained UFF model soon. You only need to train the UFF module for 1-3 epochs; make sure that you use all 10 categories of data when training this module.

nomewang commented 1 year ago

We have now released a checkpoint file for the UFF module.