Open todayplusplus opened 1 year ago
Hi @todayplusplus, thanks for raising these issues! We appreciate your rigorous write-up. I apologize for the challenges you encountered with our code!
The original codebase was jointly structured with our HILL mixup work (https://arxiv.org/abs/2211.01202), as I worked on them in tandem for my MPhil thesis. I tried to repartition the code, but forgot to remove that line!
As for `num_classes`: yes, this is a typo, apologies! You can see, though, that we override `num_classes` here, so it would not have an impact (but it is still not good code): https://github.com/cambridge-mlg/cifar-10s/blob/master/computational_experiments/utils.py#L67
As for the other shape bug: what data are you using as your train / test data? The targets should be a tensor of size [batch_size, num_classes], but that does not appear to be the case here.
Perhaps in test...:
if len(targets.shape) == 1:  # if scalar index -> one-hot encode
    targets = F.one_hot(targets, num_classes=num_classes)
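A minimal sketch of that hint, assuming the targets arrive either as class indices of shape [batch_size] or as soft labels of shape [batch_size, num_classes]. The helper name `ensure_soft_targets` is hypothetical and is not part of the repository:

```python
import torch
import torch.nn.functional as F


def ensure_soft_targets(targets: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical helper: return targets with shape [batch_size, num_classes]."""
    if targets.dim() == 1:  # class indices -> one-hot rows
        targets = F.one_hot(targets.long(), num_classes=num_classes).float()
    return targets


# Hard labels, as a standard CIFAR-10 loader would yield them.
hard = torch.tensor([3, 0, 9])
soft = ensure_soft_targets(hard, num_classes=10)

# With 2-D targets, taking the max over dim 1 is well defined.
_, max_likely = torch.max(soft, 1)
print(soft.shape, max_likely)
```

Targets that are already two-dimensional pass through unchanged, so the same call can sit in front of both soft-label and hard-label data.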
Again, apologies that this code is in a bit of a messy state; we intended to write an updated training using huggingface / timm (https://huggingface.co/docs/timm/index), as this code is somewhat outdated, but have not yet done so.
Thank you for your patient response. I will review my code based on your suggestions and hints. Thank you again, best wishes.
1) Shape error in `test()` of `train.py`
At line 308 of the function `test()` in `train.py`, running the code directly results in an error. After I added an exception handler, it showed that the problem was with the shape of `targets.data`. Since the shape of `targets.data` is `torch.Size([100])`, using `_, max_likely = torch.max(targets.data, 1)` raises an error. I worked around this by setting `max_likely = targets.data` directly, but the output is worse: all of the losses are NaN.
I would like to know if I made any mistakes that resulted in the error. I hope you can check your original code. Thank you.
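The shape error in point 1 can be reproduced in isolation. This sketch only assumes a 1-D tensor of 100 class indices, mirroring the reported `torch.Size([100])`; the random values are placeholders, not the actual data:

```python
import torch

# Placeholder targets: 100 class indices, shape torch.Size([100]).
targets = torch.randint(0, 10, (100,))

try:
    # A 1-D tensor has no dimension 1, so reducing over dim=1 fails.
    torch.max(targets, 1)
    raised = False
except (IndexError, RuntimeError) as err:
    raised = True
    print(f"torch.max failed: {err}")
```

This shows only why the call fails on index-valued targets; it does not explain the NaN losses seen after the workaround.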
2) Possible typo in `utils.py`
At line 239 of `train.py`, we use `criterion = utils.cross_entropy_loss`. But in `utils.py` the function is `def cross_entropy(preds, trgts, num_class=10)`. The parameter name seems to be a typo; I guess you meant `def cross_entropy(preds, trgts, num_classes=10)`, since in `train.py` you call `loss = criterion(outputs, targets, num_classes=num_classes)` rather than passing `num_class=num_classes`.
Finally, some imports in the code refer to modules that do not exist, such as `from data import CIFARMixHILL`.
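For point 2, here is a small illustration of how a keyword-name mismatch of this kind surfaces in Python in general. Both functions are stand-ins with placeholder bodies, not the repository's loss, and the maintainers note that `num_classes` is overridden elsewhere, so this is not a claim about the actual code path:

```python
def cross_entropy(preds, trgts, num_class=10):
    """Stand-in using the parameter name as reported; the body is a placeholder."""
    return 0.0


def cross_entropy_fixed(preds, trgts, num_classes=10):
    """Stand-in with the parameter renamed to match the call site."""
    return 0.0


try:
    # The call site passes num_classes=..., which the first signature lacks.
    cross_entropy([], [], num_classes=10)
    mismatch_error = None
except TypeError as err:
    mismatch_error = str(err)

print(mismatch_error)
print(cross_entropy_fixed([], [], num_classes=10))
```

Renaming the parameter, or passing the value positionally, would both avoid the `TypeError` in a direct call like this one.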