teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Reproducing CIFAR results #19

Open paganpasta opened 3 years ago

paganpasta commented 3 years ago

Thanks a lot for this implementation. I was wondering how I can use the repo to reproduce the CIFAR results reported in the paper. As I understand it, the one-hot target encoding serves as top-k classification with k=1. But after obtaining the logits and passing them through a softmax (putting the outputs in [0, 1]), the objective is to make the outputs follow the target ordering. How can this be achieved?

teddykoker commented 3 years ago

In the paper they follow the method described in Differentiable Ranking and Sorting using Optimal Transport:

[Two screenshots, taken 2021-05-03; images not preserved]

I believe this is what you would need to do:

import torch
import torchsort

def J(u, k=1):
    return torch.nn.functional.relu(u - k + 1)

BATCH_SIZE = 64
L = 10 # number of classes

f_w = torch.randn(BATCH_SIZE, L) # logits of the CNN
l = torch.randint(high=L, size=(BATCH_SIZE, 1)) # integer labels

R = torchsort.soft_rank(f_w) # compute ranks of logits

R_l = R.gather(-1, l) # index ranks at label (l)

loss = J(L - R_l).mean()
...

Intuitively this makes sense as we are effectively maximizing the rank of the correctly labelled logit. Once I have some more time on my hands I will work to fully reproduce their results, but this should be enough for you to get started!
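To make the intuition concrete, here is a minimal sketch of the same loss computed with hard (non-differentiable) ranks in plain Python. The helpers hard_rank and topk_loss are hypothetical names used only for illustration; they are not part of torchsort. The sketch assumes the same ascending rank convention as the snippet above, where the largest logit receives rank L.

```python
def hard_rank(scores):
    """Ascending ranks in 1..n: the largest score gets rank n (ties not handled)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0] * len(scores)
    for position, index in enumerate(order, start=1):
        ranks[index] = position
    return ranks


def topk_loss(scores, label, k=1):
    """relu(L - R_label - k + 1) for one sample, mirroring J(L - R_l) above."""
    num_classes = len(scores)
    u = num_classes - hard_rank(scores)[label]  # classes ranked above the label
    return max(0, u - k + 1)


logits = [0.2, 2.5, -1.0, 0.7]  # made-up logits for 4 classes

loss_top = topk_loss(logits, label=1)        # label holds the largest logit
loss_second = topk_loss(logits, label=3)     # one class outranks the label
loss_last = topk_loss(logits, label=2)       # three classes outrank the label
loss_top3 = topk_loss(logits, label=0, k=3)  # label is third-best, inside the top 3

print(loss_top, loss_second, loss_last, loss_top3)
```

With k=1 the loss is zero exactly when the labelled class has the largest logit, and otherwise it counts how many classes outrank it. The soft rank replaces this step function with a differentiable surrogate so that gradients can flow back to the logits.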

teddykoker commented 3 years ago

Alright, training on CIFAR10 was quite straightforward; I uploaded an example to extra/cifar10.py. Neither this paper nor the previously mentioned one gives exact details about the architecture used (kernel size, hidden size, data augmentation, etc.). They both write:

We use a vanilla CNN (4 Conv2D with 2 maxpooling layers, ReLU activation, 2 fully connected layers with batch norm on each), the ADAM optimizer (Kingma & Ba, 2014) with a constant step size of 10^−4
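The quoted description leaves kernel sizes, channel widths, and pooling placement open. As a purely illustrative sketch, the following traces spatial sizes through one possible reading of it for 32x32 CIFAR inputs. Every concrete number here (3x3 kernels with padding 1, 2x2 pooling after the second and fourth convolutions, channel widths of 32 and 64) is an assumption made for this example and is not taken from either paper or from extra/cifar10.py.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Spatial output size of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial output size of a non-overlapping max pool."""
    return size // kernel


ASSUMED_CHANNELS = [32, 32, 64, 64]  # assumption: widths are not given in the papers
POOL_AFTER = {1, 3}                  # assumption: pool after the 2nd and 4th conv

size = 32  # CIFAR images are 32x32
trace = []
for i, channels in enumerate(ASSUMED_CHANNELS):
    size = conv_out(size)
    trace.append(("conv", channels, size))
    if i in POOL_AFTER:
        size = pool_out(size)
        trace.append(("pool", channels, size))

# Input width of the first fully connected layer under these assumptions.
flat_features = ASSUMED_CHANNELS[-1] * size * size

for kind, channels, spatial in trace:
    print(f"{kind}: {channels} x {spatial} x {spatial}")
print("flattened features:", flat_features)
```

Different choices for any of these unknowns change the flattened feature count and the parameter count, which is one reason an exact reproduction is hard without the authors' details.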

[Figure: cifar10 loss plot]

Creating a model following their description as best I could, I was able to obtain a test accuracy of ~0.86, with the 600 epochs taking only about 3 hours on my NVIDIA 2070. I think this is probably the best you can do in terms of reproduction without reaching out to the original authors for these details.

You can run the example with:

python extra/cifar10.py

paganpasta commented 3 years ago

Wow! This was super helpful. Thanks @teddykoker

vltanh commented 3 years ago

Hi, thanks for the amazing work. However, I tried to reproduce the paper's results by comparing this to the conventional cross-entropy loss and noticed a few things:

I found the implementation for the loss in Cuturi et al., 2019 here. Would you mind checking it out?

The paper also mentions that

Similarly to Cuturi et al. (2019), we found that squashing the scores θ to [0, 1]^n with a logistic map was beneficial.

I wonder what this means. Should we do a softmax before ranking? Does it matter?
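One way to look at this question: a "logistic map" usually means the elementwise sigmoid, while softmax normalizes across classes, and the quoted sentence does not say which one the authors used. Both are monotone, so neither changes the ordering of the scores; what changes is the spacing between them. The sketch below checks this in plain Python on made-up logits, with all helper names being illustrative only.

```python
import math


def sigmoid(scores):
    """Elementwise logistic map into (0, 1)."""
    return [1.0 / (1.0 + math.exp(-s)) for s in scores]


def softmax(scores):
    """Softmax across the list, shifted by the max for numerical stability."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def order(values):
    """Indices that sort the values in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


logits = [4.0, -2.0, 0.5, 9.0]  # made-up scores
squashed_sigmoid = sigmoid(logits)
squashed_softmax = softmax(logits)

# Both squashings keep the hard ordering of the raw logits.
same_order = order(logits) == order(squashed_sigmoid) == order(squashed_softmax)

# The spread of the scores shrinks from 11.0 to less than 1.0 in both cases.
gap_raw = max(logits) - min(logits)
gap_sigmoid = max(squashed_sigmoid) - min(squashed_sigmoid)
gap_softmax = max(squashed_softmax) - min(squashed_softmax)

print(same_order, gap_raw, gap_sigmoid, gap_softmax)
```

Since hard ranks are unchanged, any effect of squashing on training would have to come through the soft rank, whose smoothness depends on the score scale relative to the regularization strength. Bounding the scores keeps that ratio from drifting as the logits grow, which may be why the authors found it beneficial.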

teddykoker commented 3 years ago

Thanks for the comment @vltanh. Following your advice, I have made a few changes to the cifar10 script (still a work in progress on the fix_topk branch). Adding a softmax prior to the topk loss yields a test accuracy about the same as cross entropy:

Since these are both considerably higher than the paper reports, I will be trying again without data augmentation. Additionally the regularization coefficient may need to be tuned to obtain the best results.

vltanh commented 3 years ago

I have run experiments both with and without softmax, and I find that they achieve similar results (still, it's just one test, so I'm not sure about that either).

For the current case where you use augmentation, my result is similar to yours: CELoss outperforms at almost every step until the end, where the two perform similarly.

For the case where you don't use augmentation, you might see that CELoss quickly overfits (accuracy drops after the first 30-50 epochs) while TopK loss gradually improves, but they reach similar accuracy at the end too (on CIFAR100, TopK is worse).

As for tuning the regularization coefficient, I tried L2 with 0.1, 1.0, and 10.0, and they achieve similar results, although 1.0 is marginally better.

teddykoker commented 3 years ago

Here is what I got with no augmentation; this looks similar to what you mention:

[Figure: cifar10_test_accuracy]

teddykoker commented 3 years ago

I think at this point it would make sense to reach out to the original authors and see if they could provide any insight as to what we are doing wrong. My soft sort/rank is tested to perform numerically identically to the original implementation, so it is likely just an issue with the loss function, or some difference in CNN architecture that provides different results.