libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Fix bug in cifar-10 example training script; takes accuracy ~92.5% -> ~94% #213

Closed: KellerJordan closed this 2 years ago

KellerJordan commented 2 years ago

I was playing around with the example training script for CIFAR-10 and noticed that removing test-time augmentation (TTA) improved my results from around 92-92.5% to 93% accuracy on the test set. It turns out that torch.fliplr doesn't behave as desired on batches of images: in particular, it causes a vertical rather than a horizontal flip. Fixing this bug improves the accuracy with TTA to ~94%.

Here's a demonstration of the basic problem:

>>> import torch
>>> img_t = torch.tensor([[[1, 2, 3],
...                        [4, 5, 6],
...                        [7, 8, 9]]])
>>> torch.fliplr(img_t), img_t.flip(-1)
(tensor([[[7, 8, 9],
          [4, 5, 6],
          [1, 2, 3]]]),
 tensor([[[3, 2, 1],
          [6, 5, 4],
          [9, 8, 7]]]))
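
The same check can be run on a batched tensor. What follows is a minimal sketch, not code from the training script: it assumes an (N, C, H, W) batch layout, and the tensor names are made up for illustration. torch.fliplr reverses dim 1, so on a 3-D (C, H, W) tensor it reverses the rows, and on a 4-D batch under the assumed layout dim 1 is the channel axis. Calling flip(-1) always reverses the last (width) axis.

```python
import torch

# 3-D case, as in the demonstration: shape (1, 3, 3), read as (C, H, W).
img = torch.arange(1, 10).reshape(1, 3, 3)
# fliplr reverses dim 1, which is the row (height) axis for this shape.
fliplr_is_row_flip = torch.equal(torch.fliplr(img), img.flip(1))

# 4-D case: hypothetical batch, assumed (N, C, H, W) layout.
batch = torch.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)
lr = torch.fliplr(batch)   # reverses dim 1 regardless of what that axis means
hflip = batch.flip(-1)     # reverses the last axis, i.e. the width

fliplr_is_dim1_flip = torch.equal(lr, batch.flip(1))
fliplr_is_hflip = torch.equal(lr, hflip)

print(fliplr_is_row_flip, fliplr_is_dim1_flip, fliplr_is_hflip)
```

Indexing the flip from the end of the shape (flip(-1)) keeps the operation tied to the width axis whether or not a batch dimension is present.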

And here are the results of running the training script before/after fixing.

Before fixing:

Total time: 37.84584
...
train accuracy: 98.0%
...
test accuracy: 92.3%

After fixing:

Total time: 37.21039
...
train accuracy: 99.6%
...
test accuracy: 94.2%
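
For reference, flip-based test-time augmentation with the corrected flip can be sketched as below. This is an illustration under stated assumptions, not the example script's code: the function name predict_with_hflip_tta, the stand-in linear model, and the random inputs are all hypothetical, and the batch is assumed to be laid out as (N, C, H, W).

```python
import torch
import torch.nn as nn


def predict_with_hflip_tta(model, ims):
    """Sum the logits for each image and its horizontally flipped copy."""
    model.eval()
    with torch.no_grad():
        out = model(ims)
        # flip(-1) reverses the width axis, i.e. a horizontal flip.
        out = out + model(ims.flip(-1))
    return out


# Hypothetical stand-in for a CIFAR-10 classifier (not the script's network).
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

ims = torch.randn(4, 3, 32, 32)   # fake batch, assumed (N, C, H, W)
logits = predict_with_hflip_tta(model, ims)
preds = logits.argmax(dim=1)      # predicted class index per image
print(logits.shape, preds.shape)
```

Because the logits of an image and its mirror are summed, the prediction is the same whether the original or the mirrored image is passed in.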

Thanks for the nice library!

lengstrom commented 2 years ago

thanks!