carloalbertobarbano closed this issue 3 years ago.
Hi, what is your torch version? We only tested our code with torch==1.1.0 and torchvision==0.2.2.post3
Hi, these are the torch and torchvision versions I installed (I used your requirements.txt file):
torch==1.1.0
torchvision==0.2.2.post3
@carloalbertobarbano I just found that the interface of outer_criterion is wrong.
https://github.com/clovaai/rebias/blob/master/trainer.py#L338-L341
_g_preds, _g_feats = g_net(x)
_f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)
f_loss_indep += _f_loss_indep
Learned Mixin + H does not use features, only predictions. I need to update the outer_criterion interface so that it takes g_pred as an argument; I will revise this in a few days.
Until then, for LearnedMixin + H, please change
_f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)
to
_f_loss_indep = self.outer_criterion(f_feats, _g_preds, labels=labels, f_pred=preds)
This is very strange, because I verified all arguments and all numbers before releasing the code. Sorry for the inconvenience, and thanks for reporting the bug.
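To illustrate why this call needs the bias model's predictions rather than its features, here is a minimal, single-example, pure-Python sketch of the Learned-Mixin + H loss as described by Clark et al. (2019). It is not the rebias repository's implementation: the function learned_mixin_h_loss and the gate parameters gate_w and gate_b are hypothetical. Following that paper as we read it, we assume the learned gate g(x) is computed from the main model's features, which would explain why f_feats is still passed alongside the bias predictions.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one example's logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def learned_mixin_h_loss(
    f_feats: Sequence[float],
    g_preds: Sequence[float],
    label: int,
    f_pred: Sequence[float],
    gate_w: Sequence[float],
    gate_b: float,
    entropy_weight: float,
) -> float:
    """Per-example Learned-Mixin + H loss (hypothetical sketch, not the repo's code)."""
    if len(g_preds) != len(f_pred):
        # Passing bias-model *features* here (the bug discussed above) would give a
        # vector sized by the feature dimension instead of the number of classes.
        raise ValueError("g_preds must be class logits with the same length as f_pred")
    if len(gate_w) != len(f_feats):
        raise ValueError("gate_w must match the dimension of f_feats")
    # Assumed gate g(x) = softplus(w . h + b), computed from the main model's features.
    gate = softplus(sum(w * h for w, h in zip(gate_w, f_feats)) + gate_b)
    # Bias term g(x) * log p_bias, built from the bias model's predictions.
    scaled_bias = [gate * lp for lp in log_softmax(g_preds)]
    # Ensemble softmax(log p_main + g(x) * log p_bias), trained with cross-entropy.
    combined = log_softmax([lp + s for lp, s in zip(log_softmax(f_pred), scaled_bias)])
    nll = -combined[label]
    # H penalty: entropy of softmax(g(x) * log p_bias); penalizing it discourages
    # the gate from collapsing to zero and ignoring the bias model.
    q = [math.exp(v) for v in log_softmax(scaled_bias)]
    entropy = -sum(p * math.log(p) for p in q if p > 0.0)
    return nll + entropy_weight * entropy
```

The explicit length check is deliberate: zip would otherwise silently truncate a mismatched vector, whereas passing features where class logits are expected should fail loudly.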
Thanks a lot!
@carloalbertobarbano This issue is resolved by #10. I verified that ReBias works well with the following command:
python main_biased_mnist.py --root /home/data --train_correlation 0.99 --optim AdamP
I have confirmed that ReBias reaches 90.11% unbiased accuracy with AdamP at rho=0.99, compared with my previous result of 89.60% (average over three different runs).
[2021-07-04 15:11:35] state dict is saved to ./checkpoints/last.pth, metadata: {'cur_epoch': 80, 'best_acc': 0.9011, 'scores': {'biased': {'f_acc': 1.0, 'g_0_acc': 0.9999, 'outer_0_loss': 2.9878953809384255e-05, 'inner_0_loss': -2.9878953809384255e-05}, 'rho0': {'f_acc': 0.8895, 'g_0_acc': 0.0098, 'outer_0_loss': 6.879471498859857e-06, 'inner_0_loss': -6.879471498859857e-06}, 'unbiased': {'f_acc': 0.9011, 'g_0_acc': 0.1088, 'outer_0_loss': 6.648345147550571e-06, 'inner_0_loss': -6.648345147550571e-06}}}
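The metadata in the log line above is printed as a Python dict literal, so one way to pull out the reported accuracies is ast.literal_eval. This is a minimal sketch that assumes every checkpoint log line has exactly this "metadata: {...}" shape; the helpers parse_checkpoint_log and summarize are hypothetical, and we read f_acc as the accuracy of the main (debiased) model f, which matches the 90.11% unbiased figure quoted above.

```python
import ast

LOG_LINE = (
    "[2021-07-04 15:11:35] state dict is saved to ./checkpoints/last.pth, "
    "metadata: {'cur_epoch': 80, 'best_acc': 0.9011, 'scores': "
    "{'biased': {'f_acc': 1.0, 'g_0_acc': 0.9999, "
    "'outer_0_loss': 2.9878953809384255e-05, 'inner_0_loss': -2.9878953809384255e-05}, "
    "'rho0': {'f_acc': 0.8895, 'g_0_acc': 0.0098, "
    "'outer_0_loss': 6.879471498859857e-06, 'inner_0_loss': -6.879471498859857e-06}, "
    "'unbiased': {'f_acc': 0.9011, 'g_0_acc': 0.1088, "
    "'outer_0_loss': 6.648345147550571e-06, 'inner_0_loss': -6.648345147550571e-06}}}"
)


def parse_checkpoint_log(line: str) -> dict:
    """Extract the metadata dict from a 'state dict is saved to ...' log line."""
    marker = "metadata: "
    idx = line.find(marker)
    if idx == -1:
        raise ValueError("no metadata found in log line")
    # literal_eval only evaluates Python literals, so it is safe on log text.
    return ast.literal_eval(line[idx + len(marker):].strip())


def summarize(meta: dict) -> dict:
    # Accuracy of f on each evaluation split, plus best accuracy and epoch.
    summary = {split: scores["f_acc"] for split, scores in meta["scores"].items()}
    summary["best_acc"] = meta["best_acc"]
    summary["epoch"] = meta["cur_epoch"]
    return summary


summary = summarize(parse_checkpoint_log(LOG_LINE))
print(summary)
# {'biased': 1.0, 'rho0': 0.8895, 'unbiased': 0.9011, 'best_acc': 0.9011, 'epoch': 80}
```

Applied to the final log line of each run, the same helpers would give the per-run unbiased accuracies needed for an average like the three-run figure mentioned above.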
If you have any other questions, please don't hesitate to ask.
Thanks a lot for your support
Hello, I'm trying to run your LearnedMixin implementation on MNIST but I'm getting the following error: