clovaai / rebias

Official PyTorch implementation of ReBias (Learning De-biased Representations with Biased Representations), ICML 2020
MIT License

RuntimeError for LearnedMixin on MNIST #9

Closed carloalbertobarbano closed 3 years ago

carloalbertobarbano commented 3 years ago

Hello, I'm trying to run your LearnedMixin implementation on MNIST but I'm getting the following error:

[2021-07-03 06:58:30] start training
Traceback (most recent call last):
  File "main_biased_mnist.py", line 138, in <module>
    fire.Fire(main)
  File "/home/barbano/.pyenv/versions/rebias/lib/python3.7/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/barbano/.pyenv/versions/rebias/lib/python3.7/site-packages/fire/core.py", line 471, in _Fire
    target=component.__name__)
  File "/home/barbano/.pyenv/versions/rebias/lib/python3.7/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "main_biased_mnist.py", line 134, in main
    save_dir=save_dir)
  File "/home/barbano/rebias/trainer.py", line 390, in train
    self._train_epoch(tr_loader, cur_epoch)
  File "/home/barbano/rebias/trainer.py", line 362, in _train_epoch
    self._update_f(x, labels, loss_dict=loss_dict, prefix='train__')
  File "/home/barbano/rebias/trainer.py", line 340, in _update_f
    _f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)
  File "/home/barbano/.pyenv/versions/rebias/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/barbano/rebias/criterions/comparison_methods.py", line 92, in forward
    loss = F.cross_entropy(f_pred+g_pred, labels)
RuntimeError: The size of tensor a (10) must match the size of tensor b (128) at non-singleton dimension 1
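
The mismatch can be reproduced in isolation: `f_pred` holds class logits of shape `(batch, 10)`, while `_g_feats` holds 128-dimensional feature vectors, so the element-wise sum fails at dimension 1. A minimal sketch with hypothetical tensors (batch size, class count, and feature dimension assumed from the error message):

```python
import torch
import torch.nn.functional as F

batch_size, num_classes, feat_dim = 128, 10, 128

f_pred = torch.randn(batch_size, num_classes)  # (128, 10): class logits
g_feats = torch.randn(batch_size, feat_dim)    # (128, 128): feature vectors
g_pred = torch.randn(batch_size, num_classes)  # (128, 10): class logits
labels = torch.randint(0, num_classes, (batch_size,))

# Summing logits with features cannot broadcast (10 vs. 128 at dim 1):
try:
    F.cross_entropy(f_pred + g_feats, labels)
except RuntimeError as e:
    print("RuntimeError:", e)

# With two logit tensors of matching shape, the loss computes normally:
loss = F.cross_entropy(f_pred + g_pred, labels)
print(loss.shape)  # torch.Size([]) — a scalar loss
```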
SanghyukChun commented 3 years ago

Hi, what is your torch version? We only tested our code with torch==1.1.0 and torchvision==0.2.2.post3.

carloalbertobarbano commented 3 years ago

Hi, these are the torch versions I installed (I used your requirements.txt file):

torch==1.1.0
torchvision==0.2.2.post3
SanghyukChun commented 3 years ago

@carloalbertobarbano I just found that the interface of outer_criterion is wrong. https://github.com/clovaai/rebias/blob/master/trainer.py#L338-L341

                _g_preds, _g_feats = g_net(x)

                _f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)
                f_loss_indep += _f_loss_indep

Learned-Mixin + H does not use features, only predictions, so I have to update the outer_criterion interface to take g_pred as an argument. I will revise this in a few days.

Until then, please change this line:

                _f_loss_indep = self.outer_criterion(f_feats, _g_feats, labels=labels, f_pred=preds)

to:

                _f_loss_indep = self.outer_criterion(f_feats, _g_preds, labels=labels, f_pred=preds)

for LearnedMixin + H.
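
To illustrate why the fix works, here is a minimal sketch of a criterion with this interface (hypothetical class and shapes; the actual LearnedMixin in criterions/comparison_methods.py also includes the learned mixing coefficient and entropy penalty of Learned-Mixin + H):

```python
import torch
import torch.nn.functional as F

class LearnedMixinSketch(torch.nn.Module):
    """Sketch: combines main-model logits with biased-model logits."""
    def forward(self, f_feats, g_pred, labels, f_pred):
        # f_feats is ignored here; Learned-Mixin + H only needs logits.
        # Both f_pred and g_pred must be (batch, num_classes), which is
        # why _g_preds, not _g_feats, must be passed as the second argument.
        return F.cross_entropy(f_pred + g_pred, labels)

criterion = LearnedMixinSketch()
f_feats = torch.randn(128, 128)   # unused by this criterion
f_pred = torch.randn(128, 10)     # main model logits
g_pred = torch.randn(128, 10)     # biased model logits
labels = torch.randint(0, 10, (128,))
loss = criterion(f_feats, g_pred, labels=labels, f_pred=f_pred)
print(loss.item())
```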

It is very strange, because all arguments and all numbers were verified before I released the code. Sorry for the inconvenience, and thanks for reporting the bug.

carloalbertobarbano commented 3 years ago

Thanks a lot!

SanghyukChun commented 3 years ago

@carloalbertobarbano This issue is resolved by #10. I verified that ReBias works correctly with the following command:

python main_biased_mnist.py --root /home/data --train_correlation 0.99 --optim AdamP

I have confirmed that the ReBias unbiased accuracy is 90.11% with AdamP at rho=0.99, whereas my previously reported result was 89.60% (averaged over three different runs).

[2021-07-04 15:11:35] state dict is saved to ./checkpoints/last.pth, metadata: {'cur_epoch': 80, 'best_acc': 0.9011, 'scores': {'biased': {'f_acc': 1.0, 'g_0_acc': 0.9999, 'outer_0_loss': 2.9878953809384255e-05, 'inner_0_loss': -2.9878953809384255e-05}, 'rho0': {'f_acc': 0.8895, 'g_0_acc': 0.0098, 'outer_0_loss': 6.879471498859857e-06, 'inner_0_loss': -6.879471498859857e-06}, 'unbiased': {'f_acc': 0.9011, 'g_0_acc': 0.1088, 'outer_0_loss': 6.648345147550571e-06, 'inner_0_loss': -6.648345147550571e-06}}}

If you have any other questions, please don't hesitate to ask.

carloalbertobarbano commented 3 years ago

Thanks a lot for your support