facebookresearch / ijepa

Official codebase for I-JEPA, the Image-based Joint-Embedding Predictive Architecture. First outlined in the CVPR paper, "Self-supervised learning from images with a joint-embedding predictive architecture."

Struggling to Train Downstream Classifier #58

Open NolanGC opened 2 months ago

NolanGC commented 2 months ago

Hi,

I'm working on training a downstream classifier from the ImageNet-22k checkpoint. When I use a TinyViT checkpoint, average over the first dimension of the output, and feed that into a linear classification head, the model trains as expected. However, if I replace TinyViT with the target encoder of I-JEPA, again averaging over the first dimension of the final layer's output and feeding it into a linear classification head, the model fails to train at all. Has anyone been able to successfully train on a downstream task?
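For concreteness, here is a minimal sketch of the setup I'm describing. The encoder loading is omitted and the names are illustrative; the I-JEPA target encoder is a ViT that returns patch tokens of shape [B, N, D] (no CLS token), so the pooling below is over the token dimension:

```python
import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """Frozen backbone + linear head (illustrative sketch)."""

    def __init__(self, encoder: nn.Module, embed_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False  # train only the linear head
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            tokens = self.encoder(x)  # expected shape: [B, N, D] patch tokens
        feats = tokens.mean(dim=1)    # average-pool over the token dimension
        return self.head(feats)
```

Swapping TinyViT in for `encoder` here trains fine; with the I-JEPA target encoder the same head fails to learn.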

Thank you!

prakarsh-sys commented 2 months ago

facing a similar challenge

ChristopherMarais commented 1 month ago

Same issue

FalsoMoralista commented 1 month ago

I have been able to fine-tune, but I'm experiencing some weird behavior during training, so I'm still investigating. The model does seem to train despite the variance in performance.

Accuracy drops suddenly between epochs 13 and 14, then starts to recover, but finishes at 61% top-1.

I'm fine-tuning the whole model (ImageNet-22k pretrained weights) on the PlantNet-300K dataset with 8 GPUs, batch size 64, and 4 gradient-accumulation iterations. I'm also using the original warmup scheduler with start lr = 1e-4, peak lr = 7.5e-4, and final lr = 1e-6 (the schedule shape is sketched below).
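I'm using the repo's own warmup scheduler, but for reference, here's the schedule shape written out explicitly; a hedged sketch, where the warmup and total step counts are whatever your loader/epoch setup implies:

```python
import math

def lr_at_step(step: int, warmup_steps: int, total_steps: int,
               start_lr: float = 1e-4, ref_lr: float = 7.5e-4,
               final_lr: float = 1e-6) -> float:
    """Linear warmup from start_lr to ref_lr, then cosine decay to final_lr."""
    if step < warmup_steps:
        progress = step / max(1, warmup_steps)
        return start_lr + progress * (ref_lr - start_lr)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final_lr + 0.5 * (ref_lr - final_lr) * (1.0 + math.cos(math.pi * progress))
```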

INFO:root:Epoch 10
INFO:root:[10,     0/  476] train_loss: 4.497 [wd: 7.73e-02] [lr: 6.85e-04] [mem: 4.92e+04] (656.1 ms)
INFO:root:[10,     0] grad_stats: [3.85e+03 1.13e+03] (1.97e+02, 2.23e+04)
INFO:root:[10,   100/  476] train_loss: 4.049 [wd: 7.85e-02] [lr: 6.99e-04] [mem: 4.92e+04] (617.4 ms)
INFO:root:[10,   100] grad_stats: [3.51e+03 9.40e+02] (1.87e+02, 2.35e+04)
INFO:root:[10,   200/  476] train_loss: 4.077 [wd: 7.98e-02] [lr: 7.12e-04] [mem: 4.92e+04] (618.4 ms)
INFO:root:[10,   200] grad_stats: [2.67e+03 8.93e+02] (1.71e+02, 2.19e+04)
INFO:root:[10,   300/  476] train_loss: 4.072 [wd: 8.11e-02] [lr: 7.26e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[10,   300] grad_stats: [5.57e+03 1.16e+03] (2.16e+02, 2.53e+04)
INFO:root:[10,   400/  476] train_loss: 4.086 [wd: 8.24e-02] [lr: 7.40e-04] [mem: 4.92e+04] (619.8 ms)
INFO:root:[10,   400] grad_stats: [4.33e+03 7.79e+02] (2.26e+02, 2.32e+04)
INFO:root:avg. train_loss 4.084
INFO:root:avg. test_loss 1.884 avg. Accuracy@1 60.443 - avg. Accuracy@5 82.656
INFO:root:Loss 4.17664
INFO:root:Epoch 11
INFO:root:[11,     0/  476] train_loss: 3.554 [wd: 8.34e-02] [lr: 7.50e-04] [mem: 4.92e+04] (633.5 ms)
INFO:root:[11,     0] grad_stats: [5.31e+03 9.00e+02] (1.91e+02, 2.85e+04)
INFO:root:[11,   100/  476] train_loss: 3.931 [wd: 8.48e-02] [lr: 7.50e-04] [mem: 4.92e+04] (617.2 ms)
INFO:root:[11,   100] grad_stats: [5.04e+03 7.60e+02] (1.69e+02, 2.18e+04)
INFO:root:[11,   200/  476] train_loss: 3.949 [wd: 8.62e-02] [lr: 7.50e-04] [mem: 4.92e+04] (619.0 ms)
INFO:root:[11,   200] grad_stats: [5.25e+03 9.54e+02] (2.42e+02, 3.73e+04)
INFO:root:[11,   300/  476] train_loss: 3.965 [wd: 8.76e-02] [lr: 7.50e-04] [mem: 4.92e+04] (619.7 ms)
INFO:root:[11,   300] grad_stats: [3.31e+03 9.12e+02] (1.85e+02, 2.19e+04)
INFO:root:[11,   400/  476] train_loss: 3.951 [wd: 8.91e-02] [lr: 7.49e-04] [mem: 4.92e+04] (620.1 ms)
INFO:root:[11,   400] grad_stats: [4.00e+03 8.88e+02] (1.78e+02, 2.16e+04)
INFO:root:avg. train_loss 3.957
INFO:root:avg. test_loss 1.724 avg. Accuracy@1 63.073 - avg. Accuracy@5 84.583
INFO:root:Loss 4.15005
INFO:root:Epoch 12
INFO:root:[12,     0/  476] train_loss: 4.433 [wd: 9.02e-02] [lr: 7.49e-04] [mem: 4.92e+04] (645.0 ms)
INFO:root:[12,     0] grad_stats: [4.14e+03 9.29e+02] (2.38e+02, 2.13e+04)
INFO:root:[12,   100/  476] train_loss: 3.943 [wd: 9.17e-02] [lr: 7.48e-04] [mem: 4.92e+04] (617.7 ms)
INFO:root:[12,   100] grad_stats: [3.63e+03 8.86e+02] (1.80e+02, 2.30e+04)
INFO:root:[12,   200/  476] train_loss: 3.904 [wd: 9.32e-02] [lr: 7.48e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[12,   200] grad_stats: [4.99e+03 7.64e+02] (2.01e+02, 2.53e+04)
INFO:root:[12,   300/  476] train_loss: 3.887 [wd: 9.47e-02] [lr: 7.47e-04] [mem: 4.92e+04] (619.8 ms)
INFO:root:[12,   300] grad_stats: [3.52e+03 9.62e+02] (1.92e+02, 2.12e+04)
INFO:root:[12,   400/  476] train_loss: 3.883 [wd: 9.63e-02] [lr: 7.46e-04] [mem: 4.92e+04] (620.0 ms)
INFO:root:[12,   400] grad_stats: [5.45e+03 1.10e+03] (1.86e+02, 2.32e+04)
INFO:root:avg. train_loss 3.874
INFO:root:avg. test_loss 1.557 avg. Accuracy@1 66.641 - avg. Accuracy@5 87.734
INFO:root:Loss 3.66831
INFO:root:Epoch 13
INFO:root:[13,     0/  476] train_loss: 4.735 [wd: 9.74e-02] [lr: 7.45e-04] [mem: 4.92e+04] (632.2 ms)
INFO:root:[13,     0] grad_stats: [4.15e+03 1.18e+03] (2.15e+02, 2.48e+04)
INFO:root:[13,   100/  476] train_loss: 3.768 [wd: 9.90e-02] [lr: 7.44e-04] [mem: 4.92e+04] (617.9 ms)
INFO:root:[13,   100] grad_stats: [4.05e+03 9.09e+02] (2.22e+02, 2.19e+04)
INFO:root:[13,   200/  476] train_loss: 3.823 [wd: 1.01e-01] [lr: 7.43e-04] [mem: 4.92e+04] (619.1 ms)
INFO:root:[13,   200] grad_stats: [4.47e+03 8.63e+02] (1.74e+02, 2.11e+04)
INFO:root:[13,   300/  476] train_loss: 3.829 [wd: 1.02e-01] [lr: 7.42e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[13,   300] grad_stats: [4.80e+03 9.70e+02] (1.78e+02, 2.27e+04)
INFO:root:[13,   400/  476] train_loss: 3.802 [wd: 1.04e-01] [lr: 7.41e-04] [mem: 4.92e+04] (619.5 ms)
INFO:root:[13,   400] grad_stats: [5.73e+03 1.03e+03] (2.28e+02, 2.34e+04)
INFO:root:avg. train_loss 3.793
INFO:root:avg. test_loss 1.476 avg. Accuracy@1 68.177 - avg. Accuracy@5 88.490
INFO:root:Loss 3.93889
INFO:root:Epoch 14
INFO:root:[14,     0/  476] train_loss: 4.459 [wd: 1.05e-01] [lr: 7.40e-04] [mem: 4.92e+04] (630.1 ms)
INFO:root:[14,     0] grad_stats: [5.01e+03 1.10e+03] (2.14e+02, 2.21e+04)
INFO:root:[14,   100/  476] train_loss: 3.775 [wd: 1.07e-01] [lr: 7.38e-04] [mem: 4.92e+04] (616.9 ms)
INFO:root:[14,   100] grad_stats: [4.08e+03 7.92e+02] (2.00e+02, 1.98e+04)
INFO:root:[14,   200/  476] train_loss: 3.788 [wd: 1.09e-01] [lr: 7.37e-04] [mem: 4.92e+04] (618.7 ms)
INFO:root:[14,   200] grad_stats: [5.10e+03 9.25e+02] (2.72e+02, 2.31e+04)
INFO:root:[14,   300/  476] train_loss: 3.955 [wd: 1.10e-01] [lr: 7.35e-04] [mem: 4.92e+04] (619.3 ms)
INFO:root:[14,   300] grad_stats: [1.87e+02 2.24e+02] (9.74e+00, 2.87e+04)
INFO:root:[14,   400/  476] train_loss: 4.392 [wd: 1.12e-01] [lr: 7.33e-04] [mem: 4.92e+04] (619.0 ms)
INFO:root:[14,   400] grad_stats: [2.46e+02 7.85e+01] (9.80e+00, 3.03e+04)
INFO:root:avg. train_loss 4.592
INFO:root:avg. test_loss 5.392 avg. Accuracy@1 3.932 - avg. Accuracy@5 13.177
INFO:root:Loss 5.61791
INFO:root:Epoch 15
INFO:root:[15,     0/  476] train_loss: 5.663 [wd: 1.13e-01] [lr: 7.32e-04] [mem: 4.92e+04] (632.5 ms)
INFO:root:[15,     0] grad_stats: [3.15e+02 7.60e+01] (6.88e+00, 2.38e+04)
INFO:root:[15,   100/  476] train_loss: 5.633 [wd: 1.15e-01] [lr: 7.30e-04] [mem: 4.92e+04] (615.1 ms)
INFO:root:[15,   100] grad_stats: [9.57e+02 8.18e+01] (1.83e+01, 3.07e+04)
INFO:root:[15,   200/  476] train_loss: 5.624 [wd: 1.17e-01] [lr: 7.28e-04] [mem: 4.92e+04] (616.4 ms)
INFO:root:[15,   200] grad_stats: [1.19e+03 6.76e+01] (1.11e+01, 2.71e+04)
INFO:root:[15,   300/  476] train_loss: 5.613 [wd: 1.19e-01] [lr: 7.25e-04] [mem: 4.92e+04] (617.0 ms)
INFO:root:[15,   300] grad_stats: [1.15e+03 7.75e+01] (1.24e+01, 2.87e+04)
INFO:root:[15,   400/  476] train_loss: 5.596 [wd: 1.21e-01] [lr: 7.23e-04] [mem: 4.92e+04] (617.3 ms)
INFO:root:[15,   400] grad_stats: [2.85e+03 1.16e+02] (3.21e+01, 3.27e+04)
INFO:root:avg. train_loss 5.593
INFO:root:avg. test_loss 5.094 avg. Accuracy@1 8.021 - avg. Accuracy@5 21.510
INFO:root:Loss 5.61608
INFO:root:Epoch 16
INFO:root:[16,     0/  476] train_loss: 5.488 [wd: 1.22e-01] [lr: 7.21e-04] [mem: 4.92e+04] (635.4 ms)
INFO:root:[16,     0] grad_stats: [9.14e+02 8.01e+01] (1.96e+01, 2.63e+04)
INFO:root:[16,   100/  476] train_loss: 5.529 [wd: 1.24e-01] [lr: 7.19e-04] [mem: 4.92e+04] (618.4 ms)
INFO:root:[16,   100] grad_stats: [1.41e+03 8.73e+01] (2.12e+01, 2.58e+04)
INFO:root:[16,   200/  476] train_loss: 5.514 [wd: 1.26e-01] [lr: 7.17e-04] [mem: 4.92e+04] (618.0 ms)
INFO:root:[16,   200] grad_stats: [1.66e+03 1.16e+02] (2.96e+01, 2.22e+04)
INFO:root:[16,   300/  476] train_loss: 5.494 [wd: 1.28e-01] [lr: 7.14e-04] [mem: 4.92e+04] (618.1 ms)
INFO:root:[16,   300] grad_stats: [2.46e+03 1.82e+02] (3.82e+01, 2.75e+04)
INFO:root:[16,   400/  476] train_loss: 5.475 [wd: 1.30e-01] [lr: 7.11e-04] [mem: 4.92e+04] (618.2 ms)
INFO:root:[16,   400] grad_stats: [2.29e+03 1.85e+02] (3.71e+01, 2.88e+04)
INFO:root:avg. train_loss 5.463
INFO:root:avg. test_loss 4.559 avg. Accuracy@1 13.906 - avg. Accuracy@5 33.620
INFO:root:Loss 5.28610
INFO:root:Epoch 17
INFO:root:[17,     0/  476] train_loss: 5.542 [wd: 1.31e-01] [lr: 7.09e-04] [mem: 4.92e+04] (630.0 ms)
INFO:root:[17,     0] grad_stats: [2.02e+03 1.89e+02] (4.37e+01, 2.84e+04)
INFO:root:[17,   100/  476] train_loss: 5.388 [wd: 1.33e-01] [lr: 7.06e-04] [mem: 4.92e+04] (615.8 ms)
INFO:root:[17,   100] grad_stats: [2.52e+03 2.31e+02] (4.84e+01, 3.09e+04)
INFO:root:[17,   200/  476] train_loss: 5.366 [wd: 1.35e-01] [lr: 7.03e-04] [mem: 4.92e+04] (617.0 ms)
INFO:root:[17,   200] grad_stats: [2.03e+03 1.88e+02] (5.36e+01, 2.88e+04)
INFO:root:[17,   300/  476] train_loss: 5.354 [wd: 1.37e-01] [lr: 7.00e-04] [mem: 4.92e+04] (617.6 ms)
INFO:root:[17,   300] grad_stats: [2.76e+03 2.69e+02] (6.21e+01, 2.96e+04)
INFO:root:[17,   400/  476] train_loss: 5.337 [wd: 1.39e-01] [lr: 6.97e-04] [mem: 4.92e+04] (618.0 ms)
INFO:root:[17,   400] grad_stats: [3.61e+03 3.98e+02] (1.36e+02, 4.85e+04)
INFO:root:avg. train_loss 5.328
INFO:root:avg. test_loss 4.345 avg. Accuracy@1 17.656 - avg. Accuracy@5 38.385
INFO:root:Loss 5.45324
INFO:root:Epoch 18
INFO:root:[18,     0/  476] train_loss: 5.148 [wd: 1.41e-01] [lr: 6.95e-04] [mem: 4.92e+04] (640.6 ms)
INFO:root:[18,     0] grad_stats: [7.65e+03 3.98e+02] (1.67e+02, 5.79e+04)
INFO:root:[18,   100/  476] train_loss: 5.242 [wd: 1.43e-01] [lr: 6.92e-04] [mem: 4.92e+04] (616.3 ms)
INFO:root:[18,   100] grad_stats: [3.11e+03 5.56e+02] (1.17e+02, 5.26e+04)
INFO:root:[18,   200/  476] train_loss: 5.248 [wd: 1.45e-01] [lr: 6.88e-04] [mem: 4.92e+04] (617.5 ms)
INFO:root:[18,   200] grad_stats: [4.31e+03 4.79e+02] (1.39e+02, 5.41e+04)
INFO:root:[18,   300/  476] train_loss: 5.238 [wd: 1.47e-01] [lr: 6.85e-04] [mem: 4.92e+04] (617.9 ms)
INFO:root:[18,   300] grad_stats: [4.13e+03 4.89e+02] (1.54e+02, 5.10e+04)
INFO:root:[18,   400/  476] train_loss: 5.228 [wd: 1.49e-01] [lr: 6.81e-04] [mem: 4.92e+04] (618.1 ms)
INFO:root:[18,   400] grad_stats: [5.44e+03 7.96e+02] (2.02e+02, 5.57e+04)
INFO:root:avg. train_loss 5.227
INFO:root:avg. test_loss 4.025 avg. Accuracy@1 22.839 - avg. Accuracy@5 46.458
INFO:root:Loss 5.24773
INFO:root:Epoch 19
INFO:root:[19,     0/  476] train_loss: 4.819 [wd: 1.51e-01] [lr: 6.78e-04] [mem: 4.92e+04] (643.2 ms)
INFO:root:[19,     0] grad_stats: [5.64e+03 7.13e+02] (2.04e+02, 5.72e+04)
INFO:root:[19,   100/  476] train_loss: 5.174 [wd: 1.53e-01] [lr: 6.75e-04] [mem: 4.92e+04] (616.7 ms)
INFO:root:[19,   100] grad_stats: [5.91e+03 7.15e+02] (1.63e+02, 4.58e+04)
INFO:root:[19,   200/  476] train_loss: 5.180 [wd: 1.55e-01] [lr: 6.71e-04] [mem: 4.92e+04] (618.0 ms)
INFO:root:[19,   200] grad_stats: [4.59e+03 5.80e+02] (2.09e+02, 6.19e+04)
INFO:root:[19,   300/  476] train_loss: 5.178 [wd: 1.57e-01] [lr: 6.67e-04] [mem: 4.92e+04] (618.7 ms)
INFO:root:[19,   300] grad_stats: [3.64e+03 5.87e+02] (1.61e+02, 5.17e+04)
INFO:root:[19,   400/  476] train_loss: 5.170 [wd: 1.59e-01] [lr: 6.63e-04] [mem: 4.92e+04] (619.0 ms)
INFO:root:[19,   400] grad_stats: [5.29e+03 6.95e+02] (1.24e+02, 5.52e+04)
INFO:root:avg. train_loss 5.159
INFO:root:avg. test_loss 3.793 avg. Accuracy@1 26.510 - avg. Accuracy@5 50.026
INFO:root:Loss 5.46958
INFO:root:Epoch 20
INFO:root:[20,     0/  476] train_loss: 4.489 [wd: 1.61e-01] [lr: 6.60e-04] [mem: 4.92e+04] (629.6 ms)
INFO:root:[20,     0] grad_stats: [7.00e+03 9.99e+02] (2.63e+02, 5.66e+04)
INFO:root:[20,   100/  476] train_loss: 5.087 [wd: 1.63e-01] [lr: 6.56e-04] [mem: 4.92e+04] (617.4 ms)
INFO:root:[20,   100] grad_stats: [8.41e+03 8.56e+02] (2.36e+02, 5.00e+04)
INFO:root:[20,   200/  476] train_loss: 5.087 [wd: 1.65e-01] [lr: 6.52e-04] [mem: 4.92e+04] (618.1 ms)
INFO:root:[20,   200] grad_stats: [3.79e+03 6.94e+02] (1.65e+02, 5.14e+04)
INFO:root:[20,   300/  476] train_loss: 5.084 [wd: 1.67e-01] [lr: 6.48e-04] [mem: 4.92e+04] (618.5 ms)
INFO:root:[20,   300] grad_stats: [8.43e+03 7.70e+02] (2.28e+02, 5.14e+04)
INFO:root:[20,   400/  476] train_loss: 5.214 [wd: 1.69e-01] [lr: 6.44e-04] [mem: 4.92e+04] (618.3 ms)
INFO:root:[20,   400] grad_stats: [2.92e+03 1.13e+02] (3.17e+01, 5.47e+04)
INFO:root:avg. train_loss 5.269
INFO:root:avg. test_loss 5.074 avg. Accuracy@1 7.344 - avg. Accuracy@5 20.964
INFO:root:Loss 5.59873
INFO:root:Epoch 21
INFO:root:[21,     0/  476] train_loss: 5.450 [wd: 1.71e-01] [lr: 6.40e-04] [mem: 4.92e+04] (629.3 ms)
INFO:root:[21,     0] grad_stats: [2.59e+03 2.03e+02] (4.63e+01, 5.46e+04)
INFO:root:[21,   100/  476] train_loss: 5.513 [wd: 1.73e-01] [lr: 6.36e-04] [mem: 4.92e+04] (615.8 ms)
INFO:root:[21,   100] grad_stats: [4.34e+03 2.78e+02] (9.16e+01, 5.81e+04)
INFO:root:[21,   200/  476] train_loss: 5.493 [wd: 1.75e-01] [lr: 6.31e-04] [mem: 4.92e+04] (617.9 ms)
INFO:root:[21,   200] grad_stats: [3.36e+03 2.82e+02] (9.63e+01, 5.41e+04)
INFO:root:[21,   300/  476] train_loss: 5.457 [wd: 1.78e-01] [lr: 6.27e-04] [mem: 4.92e+04] (618.9 ms)
INFO:root:[21,   300] grad_stats: [6.47e+03 2.78e+02] (1.51e+02, 5.20e+04)
INFO:root:[21,   400/  476] train_loss: 5.418 [wd: 1.80e-01] [lr: 6.22e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[21,   400] grad_stats: [6.21e+03 3.90e+02] (1.77e+02, 5.47e+04)
INFO:root:avg. train_loss 5.394
INFO:root:avg. test_loss 4.142 avg. Accuracy@1 19.609 - avg. Accuracy@5 41.302
INFO:root:Loss 5.35757
INFO:root:Epoch 22
INFO:root:[22,     0/  476] train_loss: 4.894 [wd: 1.82e-01] [lr: 6.19e-04] [mem: 4.92e+04] (644.9 ms)
INFO:root:[22,     0] grad_stats: [5.48e+03 5.98e+02] (1.97e+02, 5.11e+04)
INFO:root:[22,   100/  476] train_loss: 5.257 [wd: 1.84e-01] [lr: 6.14e-04] [mem: 4.92e+04] (617.6 ms)
INFO:root:[22,   100] grad_stats: [7.61e+03 5.17e+02] (1.84e+02, 5.37e+04)
INFO:root:[22,   200/  476] train_loss: 5.225 [wd: 1.86e-01] [lr: 6.09e-04] [mem: 4.92e+04] (619.0 ms)
INFO:root:[22,   200] grad_stats: [4.16e+03 5.09e+02] (1.94e+02, 4.94e+04)
INFO:root:[22,   300/  476] train_loss: 5.217 [wd: 1.88e-01] [lr: 6.04e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[22,   300] grad_stats: [4.94e+03 5.14e+02] (2.17e+02, 5.49e+04)
INFO:root:[22,   400/  476] train_loss: 5.201 [wd: 1.91e-01] [lr: 5.99e-04] [mem: 4.92e+04] (619.4 ms)
INFO:root:[22,   400] grad_stats: [7.08e+03 7.27e+02] (2.60e+02, 5.48e+04)
INFO:root:avg. train_loss 5.187
INFO:root:avg. test_loss 3.926 avg. Accuracy@1 24.401 - avg. Accuracy@5 47.578
INFO:root:Loss 5.40776
INFO:root:Epoch 23
INFO:root:[23,     0/  476] train_loss: 4.989 [wd: 1.92e-01] [lr: 5.96e-04] [mem: 4.92e+04] (619.9 ms)
INFO:root:[23,     0] grad_stats: [5.67e+03 6.02e+02] (2.22e+02, 5.28e+04)
INFO:root:[23,   100/  476] train_loss: 5.093 [wd: 1.95e-01] [lr: 5.91e-04] [mem: 4.92e+04] (619.2 ms)
INFO:root:[23,   100] grad_stats: [4.35e+03 6.59e+02] (2.53e+02, 5.47e+04)
INFO:root:[23,   200/  476] train_loss: 5.101 [wd: 1.97e-01] [lr: 5.85e-04] [mem: 4.92e+04] (620.5 ms)
INFO:root:[23,   200] grad_stats: [6.09e+03 5.33e+02] (2.17e+02, 5.11e+04)
INFO:root:[23,   300/  476] train_loss: 5.100 [wd: 1.99e-01] [lr: 5.80e-04] [mem: 4.92e+04] (620.6 ms)
INFO:root:[23,   300] grad_stats: [5.40e+03 5.50e+02] (1.88e+02, 5.01e+04)
INFO:root:[23,   400/  476] train_loss: 5.095 [wd: 2.01e-01] [lr: 5.75e-04] [mem: 4.92e+04] (620.5 ms)
INFO:root:[23,   400] grad_stats: [2.99e+03 4.93e+02] (1.65e+02, 4.42e+04)
INFO:root:avg. train_loss 5.086
INFO:root:avg. test_loss 3.670 avg. Accuracy@1 27.943 - avg. Accuracy@5 50.911
INFO:root:Loss 5.18562
INFO:root:Epoch 24
INFO:root:[24,     0/  476] train_loss: 5.132 [wd: 2.03e-01] [lr: 5.71e-04] [mem: 4.92e+04] (644.7 ms)
INFO:root:[24,     0] grad_stats: [4.16e+03 5.98e+02] (2.19e+02, 6.29e+04)
INFO:root:[24,   100/  476] train_loss: 5.092 [wd: 2.05e-01] [lr: 5.66e-04] [mem: 4.92e+04] (617.5 ms)
INFO:root:[24,   100] grad_stats: [4.94e+03 6.53e+02] (2.57e+02, 5.87e+04)
INFO:root:[24,   200/  476] train_loss: 5.074 [wd: 2.08e-01] [lr: 5.60e-04] [mem: 4.92e+04] (618.8 ms)
INFO:root:[24,   200] grad_stats: [5.29e+03 6.18e+02] (2.10e+02, 5.37e+04)
INFO:root:[24,   300/  476] train_loss: 5.050 [wd: 2.10e-01] [lr: 5.55e-04] [mem: 4.92e+04] (619.2 ms)
INFO:root:[24,   300] grad_stats: [5.52e+03 9.11e+02] (2.67e+02, 5.14e+04)
INFO:root:[24,   400/  476] train_loss: 5.045 [wd: 2.12e-01] [lr: 5.50e-04] [mem: 4.92e+04] (619.5 ms)
INFO:root:[24,   400] grad_stats: [4.64e+03 6.11e+02] (1.98e+02, 5.56e+04)
INFO:root:avg. train_loss 5.041
INFO:root:avg. test_loss 3.615 avg. Accuracy@1 29.688 - avg. Accuracy@5 53.125
INFO:root:Loss 4.70165
INFO:root:Epoch 25
INFO:root:[25,     0/  476] train_loss: 5.176 [wd: 2.14e-01] [lr: 5.45e-04] [mem: 4.92e+04] (641.1 ms)
INFO:root:[25,     0] grad_stats: [3.23e+03 5.40e+02] (1.66e+02, 5.01e+04)
INFO:root:[25,   100/  476] train_loss: 5.060 [wd: 2.16e-01] [lr: 5.40e-04] [mem: 4.92e+04] (617.8 ms)
INFO:root:[25,   100] grad_stats: [6.08e+03 7.11e+02] (2.26e+02, 5.61e+04)
INFO:root:[25,   200/  476] train_loss: 5.008 [wd: 2.19e-01] [lr: 5.34e-04] [mem: 4.92e+04] (618.6 ms)
INFO:root:[25,   200] grad_stats: [4.92e+03 1.14e+03] (2.04e+02, 5.64e+04)
INFO:root:[25,   300/  476] train_loss: 4.994 [wd: 2.21e-01] [lr: 5.29e-04] [mem: 4.92e+04] (619.2 ms)
INFO:root:[25,   300] grad_stats: [6.11e+03 6.33e+02] (2.56e+02, 5.32e+04)
INFO:root:[25,   400/  476] train_loss: 4.989 [wd: 2.23e-01] [lr: 5.23e-04] [mem: 4.92e+04] (619.7 ms)
INFO:root:[25,   400] grad_stats: [4.88e+03 8.24e+02] (2.98e+02, 5.51e+04)
INFO:root:avg. train_loss 4.982
INFO:root:avg. test_loss 3.469 avg. Accuracy@1 31.068 - avg. Accuracy@5 55.781
INFO:root:Loss 5.30995
bdytx5 commented 1 week ago

Anyone able to find a solution?

FalsoMoralista commented 6 days ago

Other than this weird behavior of accuracy dropping out of nowhere (which I suspect is related either to the size of the dataset or to the learning-rate value), I didn't have problems with fine-tuning. I have also compared fine-tuning and linear probing on the Intel dataset (https://www.kaggle.com/datasets/puneet6060/intel-image-classification) and didn't have any problems with it either.
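For anyone trying to reproduce this, here's a minimal sketch of the gradient-accumulation loop matching the setup I described earlier (4 accumulation steps; model, optimizer, and scheduler construction omitted, and the names are illustrative). My full fine-tuning code is linked below.

```python
import torch
import torch.nn.functional as F

ACCUM_ITERS = 4  # matches the 4 gradient-accumulation iterations above

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        # scale the loss so accumulated gradients average over micro-batches
        loss = F.cross_entropy(model(images), labels) / ACCUM_ITERS
        loss.backward()  # gradients accumulate until the optimizer step
        if (step + 1) % ACCUM_ITERS == 0:
            optimizer.step()
            optimizer.zero_grad()
```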

fine-tuning code