tk1980 / TwoWayMultiLabelLoss

Two-way Multi-Label Loss

loss can return NaN from the current implementation... #2

Open perathambkk opened 1 year ago

perathambkk commented 1 year ago

Thanks and congrats on this new SOTA work on multi-label loss functions!!

However,...

If a batch contains only negative samples (quite likely with small batch sizes, e.g. around 50, when the data is not well mixed), and especially when the net/model correctly predicts all of them (or predicts all zeros for the whole batch), the current implementation of the TwoWay loss in this repo returns NaN.

It seems that all four logit terms in this line are empty tensors.

torch.nn.functional.softplus(nlogit_class + plogit_class).mean() + \
    torch.nn.functional.softplus(nlogit_sample + plogit_sample).mean()

This can easily be reproduced by feeding a torch.zeros tensor (or any constant-valued tensor) as the input parameter x to the loss; class_mask and sample_mask then become all False.
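For concreteness, here is a minimal sketch of the failure mode. The mask construction below is an assumption inferred from the variable names in the repo (class_mask, sample_mask), not the verbatim implementation:

```python
import torch

# An all-negative batch: no positive label anywhere in the 50x80 target matrix
y = torch.zeros(50, 80)

# Assumed mask construction, following the repo's variable names:
class_mask = (y > 0).any(dim=0)   # all False -> no class has a positive sample
sample_mask = (y > 0).any(dim=1)  # all False -> no sample has a positive label

x = torch.zeros(50, 80)           # constant logits, e.g. an untrained head

# Indexing with an all-False boolean mask yields an empty tensor ...
plogit_class = torch.logsumexp(-x, dim=0)[class_mask]
print(plogit_class.numel())       # 0

# ... and the mean of an empty tensor is NaN, which poisons the total loss
print(plogit_class.mean())        # tensor(nan)
```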

(A naive implementation of the TwoWay loss does not have this problem, though....)

((Alternatively, one could explicitly check for such empty tensors, or add some bias values to the classifier head at initialization.... A sketch of the first option follows below.))
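A minimal sketch of the explicit empty-tensor check suggested above (safe_mean is a hypothetical helper, not part of the repo):

```python
import torch

def safe_mean(t: torch.Tensor) -> torch.Tensor:
    """Mean that returns 0 instead of NaN when the tensor is empty."""
    return t.mean() if t.numel() > 0 else t.new_zeros(())

# The repo's return line could then be guarded roughly like this:
# return safe_mean(torch.nn.functional.softplus(nlogit_class + plogit_class)) + \
#        safe_mean(torch.nn.functional.softplus(nlogit_sample + plogit_sample))
```

Whether 0 is the right value to contribute for an all-negative batch is a modeling decision; filtering such batches out in the dataloader would be another option.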

Jordy-VL commented 8 months ago

Good points! On a different dataset, I was not able to reproduce this loss function's superiority over, for example, ASL. Batch size has a huge impact on the sample-wise approximation part.