KaiyangZhou / pytorch-center-loss

PyTorch implementation of Center Loss
MIT License

How many classes? #19

Closed: msaadsaeed closed this issue 1 year ago

msaadsaeed commented 3 years ago

Hello! I am trying to train a model with center loss. My dataset has 901 classes. I am creating mini-batches after shuffling the training data, with a batch size of 128, such that:

feats_dim = [128, 512] per batch
labels_dim = [128] per batch

Then:

center_loss = CenterLoss(num_classes=????, feat_dim=512, use_gpu=False)

What should I pass as num_classes: 901 (the actual number of classes) or 128 (the classes in the current batch)?
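To see why a per-batch count cannot work, here is a minimal sketch using random stand-in data with the shapes from the question (the tensors and the `one_hot_mask` helper are hypothetical, not part of the repository). Labels in a shuffled batch are dataset-wide class indices from 0 to 900, and a batch of 128 samples usually contains fewer than 128 distinct classes. A mask with only 128 columns has no column for any label of 128 or above, so those samples would silently drop out of the loss.

```python
import torch

torch.manual_seed(0)

TOTAL_CLASSES = 901  # classes in the whole dataset (from the question)
BATCH_SIZE = 128     # samples per mini-batch (from the question)
FEAT_DIM = 512

# Random stand-in for one shuffled mini-batch; labels are drawn from all 901 classes.
feats = torch.randn(BATCH_SIZE, FEAT_DIM)                # shape [128, 512]
labels = torch.randint(0, TOTAL_CLASSES, (BATCH_SIZE,))  # shape [128]


def one_hot_mask(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: boolean [batch, num_classes] mask, True at each sample's class."""
    classes = torch.arange(num_classes)
    # [batch, 1] == [1, num_classes] broadcasts to [batch, num_classes].
    return labels.unsqueeze(1) == classes.unsqueeze(0)


print("distinct classes in this batch:", labels.unique().numel())

for num_classes in (TOTAL_CLASSES, BATCH_SIZE):
    mask = one_hot_mask(labels, num_classes)
    # A row with no True entry is a sample whose class has no column, so the loss ignores it.
    unmatched = int((mask.sum(dim=1) == 0).sum())
    print(f"num_classes={num_classes}: samples with no matching column = {unmatched}")
```

With num_classes=901 every row has exactly one match. With num_classes=128, every sample whose label is 128 or higher has none, which is typically most of the batch when labels span 0 to 900.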

hamedrq7 commented 1 year ago

From my understanding of the code, num_classes in CenterLoss is used to create the final one-hot vector (the mask variable), so you should pass 901 even if not all 901 classes appear in the current mini-batch. Also, the original paper mentions that, because training uses mini-batches, some centers may not be updated in a given mini-batch: a batch of 128 samples contains at most 128 distinct classes, so the centers of the remaining classes (at least 901 - 128 = 773) are not updated in that step.
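The mechanism described in this comment can be checked with a small re-implementation. The sketch below is hypothetical and written in the spirit of the repository's CenterLoss, not copied from it: `SimpleCenterLoss` keeps one learnable center per class in a `[num_classes, feat_dim]` parameter, builds the one-hot mask from `num_classes`, and uses a factor of 1/2 with a batch average (an assumed normalization). It also assumes the centers are optimized with plain SGD, without momentum or weight decay.

```python
import torch
import torch.nn as nn


class SimpleCenterLoss(nn.Module):
    """Hypothetical center loss: sum_i ||x_i - c_{y_i}||^2 / 2, averaged over the batch."""

    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        self.num_classes = num_classes
        # One learnable center per class in the whole dataset: [num_classes, feat_dim].
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Squared distance to every center via ||x||^2 + ||c||^2 - 2 x.c: [batch, num_classes].
        distmat = (
            x.pow(2).sum(dim=1, keepdim=True)
            + self.centers.pow(2).sum(dim=1).unsqueeze(0)
            - 2 * x @ self.centers.t()
        )
        # One-hot mask keeping only each sample's own class column.
        classes = torch.arange(self.num_classes, device=x.device)
        mask = (labels.unsqueeze(1) == classes.unsqueeze(0)).float()
        return 0.5 * (distmat * mask).sum() / x.size(0)


torch.manual_seed(0)
NUM_CLASSES, FEAT_DIM, BATCH_SIZE = 901, 512, 128

center_loss = SimpleCenterLoss(num_classes=NUM_CLASSES, feat_dim=FEAT_DIM)
# Plain SGD (no momentum, no weight decay): a zero gradient means no movement.
optimizer = torch.optim.SGD(center_loss.parameters(), lr=0.5)

feats = torch.randn(BATCH_SIZE, FEAT_DIM)
labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,))

before = center_loss.centers.detach().clone()
optimizer.zero_grad()
loss = center_loss(feats, labels)
loss.backward()
optimizer.step()

present = torch.zeros(NUM_CLASSES, dtype=torch.bool)
present[labels] = True  # classes that occur in this batch
moved = (center_loss.centers.detach() != before).any(dim=1)

print("classes in batch:", int(present.sum()))
print("centers updated:", int(moved.sum()))
print("only present classes moved:", torch.equal(moved, present))
```

With num_classes=901, only the centers of classes that appear in the batch receive a nonzero gradient and move; every other center stays exactly where it was for that step. With momentum or weight decay, centers with a zero gradient can still drift, so this check depends on the plain-SGD assumption.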

msaadsaeed commented 1 year ago

Yes, exactly. Thanks for the nice explanation.