mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

fix traindiffs #649

Closed: chandramouli-sastry closed this 8 months ago

chandramouli-sastry commented 8 months ago

Current results from the traindiff tests are below. Note that the PyTorch logs for fastmri are the same as in the screenshot above:

=================================================================Testing criteo1tb==================================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      0.67088       |      0.67088       |       1.3018       |       1.3018       |      0.67257       |      0.67257       |
|         1          |       0.6692       |       0.6692       |      1.29663       |      1.29663       |     0.67087996     |      0.67088       |
|         2          |      0.66754       |      0.66754       |      1.29149       |      1.29149       |       0.6692       |       0.6692       |
|         3          |      0.66589       |      0.66589       |      1.28631       |      1.28631       |     0.66753995     |      0.66754       |
|         4          |      0.66423       |      0.66423       |      1.28695       |      1.28695       |      0.66589       |      0.66589       |
|         5          |      0.66259       |      0.66259       |       1.2821       |       1.2821       |      0.66423       |      0.66423       |
|         6          |      0.66097       |      0.66097       |      1.27727       |      1.27727       |     0.66258997     |      0.66259       |
|         7          |      0.65937       |      0.65937       |      1.26699       |      1.26699       |      0.66097       |      0.66097       |
|         8          |      0.65779       |      0.65779       |     1.2631999      |       1.2632       |      0.65937       |      0.65937       |
|         9          |      0.65623       |      0.65623       |      1.25321       |      1.25321       |      0.65779       |      0.65779       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_criteo1tb
[ RUN      ] ModelDiffTest.test_workload_fastmri
==================================================================Testing fastmri===================================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      0.81652       |      1.11668       |     293615.97      |      1.09796       |     0.81365997     |      1.11789       |
|         1          |      0.81447       |      1.11548       |     15497.995      |      1.09571       |      0.81652       |      1.11668       |
|         2          |      0.81426       |      1.11429       |     5004.3184      |      1.09368       |      0.81447       |      1.11548       |
|         3          |      0.81376       |       1.1131       |      3284.724      |      1.09162       |      0.81426       |      1.11429       |
|         4          |      0.81358       |      1.11191       |     3348.0242      |      1.08919       |      0.81376       |       1.1131       |
|         5          |      0.81294       |      1.11073       |      3152.087      |      1.08695       |      0.81358       |      1.11191       |
|         6          |      0.81295       |      1.10956       |     3772.5212      |      1.08475       |      0.81294       |      1.11073       |
|         7          |      0.81231       |      1.10839       |      6063.978      |      1.08248       |     0.81294996     |      1.10956       |
|         8          |      0.81175       |      1.10722       |     4311.5225      |      1.08031       |      0.81231       |      1.10839       |
|         9          |      0.81125       |      1.10606       |     1400.8125      |      1.07817       |      0.81175       |      1.10722       |
====================================================================================================================================================
[  FAILED  ] ModelDiffTest.test_workload_fastmri
[ RUN      ] ModelDiffTest.test_workload_imagenet_resnet
==============================================================Testing imagenet_resnet===============================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      6.92439       |      6.92492       |     4.5467997      |       4.5468       |      6.95988       |      6.95988       |
|         1          |      6.90589       |      6.90711       |      4.54643       |      4.54642       |      6.93921       |      6.93921       |
|         2          |      6.89025       |      6.89222       |      4.54502       |      4.54502       |      6.91854       |      6.91854       |
|         3          |      6.87629       |       6.8791       |       4.5436       |       4.5436       |     6.8978896      |      6.89789       |
|         4          |      6.86291       |       6.8665       |      4.54214       |      4.54214       |     6.8772497      |      6.87725       |
|         5          |      6.84933       |       6.8537       |      4.54066       |      4.54066       |      6.85662       |      6.85662       |
|         6          |      6.83649       |      6.84166       |      4.53916       |      4.53916       |      6.83601       |      6.83601       |
|         7          |      6.82405       |      6.82978       |      4.53763       |      4.53762       |     6.8154097      |      6.81541       |
|         8          |      6.81047       |      6.81688       |      4.53606       |      4.53606       |      6.79482       |      6.79482       |
|         9          |      6.79586       |       6.8029       |      4.53447       |      4.53446       |      6.77425       |      6.77425       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_imagenet_resnet
[ RUN      ] ModelDiffTest.test_workload_imagenet_vit
================================================================Testing imagenet_vit================================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      6.89897       |      6.89897       |      2.96483       |      2.96483       |     6.9077597      |      6.90776       |
|         1          |      6.89018       |      6.89018       |      2.96448       |      2.96448       |     6.8989697      |      6.89896       |
|         2          |      6.88139       |      6.88139       |      2.96432       |      2.96432       |     6.8901796      |      6.89018       |
|         3          |       6.8726       |       6.8726       |      2.96436       |      2.96436       |     6.8813896      |      6.88139       |
|         4          |      6.86381       |      6.86381       |     2.9645798      |      2.96458       |     6.8725996      |       6.8726       |
|         5          |      6.85502       |      6.85502       |       2.965        |       2.965        |      6.86381       |      6.86381       |
|         6          |      6.84623       |      6.84623       |       2.9656       |       2.9656       |      6.85502       |      6.85502       |
|         7          |      6.83742       |      6.83743       |      2.96639       |      2.96639       |      6.84623       |      6.84623       |
|         8          |      6.82862       |      6.82862       |      2.96736       |      2.96736       |      6.83742       |      6.83743       |
|         9          |       6.8198       |       6.8198       |      2.96852       |      2.96852       |      6.82862       |      6.82862       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_imagenet_vit
[ RUN      ] ModelDiffTest.test_workload_librispeech_conformer
===========================================================Testing librispeech_conformer============================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      14.54747      |      14.54745      |     28.267809      |      28.26826      |      15.8916       |      15.89159      |
|         1          |      9.84034       |       9.8415       |      64.25728      |      64.23972      |      14.57298      |      14.57295      |
|         2          |     7.3346796      |      7.33475       |      62.99545      |      63.00708      |      10.16284      |      10.1639       |
|         3          |     7.3132997      |      7.31329       |      17.53098      |      17.53617      |      7.43671       |      7.43681       |
|         4          |     7.3249598      |      7.32494       |      2.25558       |       2.2552       |     7.2687497      |      7.26876       |
|         5          |      7.32605       |      7.32604       |      1.34152       |      1.34181       |      7.26507       |      7.26508       |
|         6          |      7.32504       |      7.32503       |     1.3000699      |      1.29991       |      7.26331       |      7.26332       |
|         7          |      7.32365       |      7.32363       |     1.2951599      |      1.29519       |     7.2616296      |      7.26163       |
|         8          |      7.32218       |      7.32216       |      1.29174       |      1.29187       |     7.2599497      |      7.25996       |
|         9          |     7.3206997      |      7.32069       |      1.28801       |      1.28804       |      7.25829       |      7.25829       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_librispeech_conformer
[ RUN      ] ModelDiffTest.test_workload_librispeech_deepspeech
===========================================================Testing librispeech_deepspeech===========================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      15.89399      |      15.89381      |     111.24035      |     111.27466      |     15.952009      |      15.95201      |
|         1          |      15.9557       |      15.96647      |     171.42363      |     187.37987      |      14.04443      |      14.04259      |
|         2          |      15.92406      |      15.9933       |      152.6857      |     120.70369      |      9.98144       |      10.01006      |
|         3          |      15.89225      |      15.92346      |     40.099438      |      32.75953      |      8.51132       |      8.45742       |
|         4          |     15.8683195     |      15.94237      |      12.85541      |      19.43173      |      8.233049      |      8.22257       |
|         5          |     15.853029      |      15.92554      |      9.27176       |      10.63069      |      8.15916       |      8.14575       |
|         6          |      15.83867      |      15.91145      |      9.198649      |      11.25572      |      8.10091       |      8.09044       |
|         7          |     15.822149      |      15.91138      |     10.2299795     |      10.06033      |      8.06234       |      8.05373       |
|         8          |      15.80604      |      15.89812      |     11.217959      |      8.19493       |      8.02546       |      8.00401       |
|         9          |     15.807059      |      15.88518      |      9.58568       |      10.37074      |       7.9919       |       7.9664       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_librispeech_deepspeech
[ RUN      ] ModelDiffTest.test_workload_ogbg
====================================================================Testing ogbg====================================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      0.77482       |      0.77482       |      0.81104       |      0.81104       |      0.77691       |      0.77691       |
|         1          |      0.77573       |      0.77573       |      552.9373      |     552.93719      |     0.77481997     |      0.77482       |
|         2          |      0.77504       |      0.77504       |      0.83438       |      0.83438       |     0.77572995     |      0.77573       |
|         3          |      0.77435       |      0.77435       |     0.83116996     |      0.83117       |      0.77504       |      0.77504       |
|         4          |      0.77365       |      0.77365       |     0.83388996     |      0.83389       |      0.77435       |      0.77435       |
|         5          |      0.77295       |      0.77295       |      0.83667       |      0.83667       |      0.77365       |      0.77365       |
|         6          |      0.77226       |      0.77226       |      0.83329       |      0.83329       |      0.77295       |      0.77295       |
|         7          |      0.77157       |      0.77157       |     0.83036995     |      0.83037       |      0.77226       |      0.77226       |
|         8          |       0.7709       |       0.7709       |     0.82502997     |      0.82503       |     0.77156997     |      0.77157       |
|         9          |      0.77022       |      0.77022       |     0.82449996     |       0.8245       |     0.77089995     |       0.7709       |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_ogbg
[ RUN      ] ModelDiffTest.test_workload_wmt
====================================================================Testing wmt=====================================================================
|        Iter        |     Eval (jax)     |    Eval (torch)    |  Grad Norm (jax)   | Grad Norm (torch)  |  Train Loss (jax)  | Train Loss (torch) |
====================================================================================================================================================
|         0          |      11.75565      |      11.75565      |      12.06393      |      12.06395      |     11.8928995     |      11.8929       |
|         1          |      11.64717      |      11.64717      |     10.703939      |      10.70386      |      11.75565      |      11.75565      |
|         2          |      11.56005      |      11.56005      |       9.5784       |      9.57842       |      11.64716      |      11.64716      |
|         3          |      11.48919      |      11.48919      |      8.62245       |       8.6224       |      11.56005      |      11.56005      |
|         4          |      11.43075      |      11.43075      |     7.8199196      |      7.81998       |      11.48919      |      11.48919      |
|         5          |      11.38191      |      11.38191      |     7.1374598      |      7.13747       |      11.43075      |      11.43075      |
|         6          |      11.34053      |      11.34053      |      6.55883       |      6.55888       |     11.381909      |      11.38191      |
|         7          |      11.30505      |      11.30505      |      6.06523       |       6.0652       |     11.340529      |      11.34052      |
|         8          |      11.27429      |      11.27429      |      5.63974       |      5.63975       |      11.30505      |      11.30505      |
|         9          |      11.2474       |      11.2474       |     5.2683797      |       5.2684       |      11.27429      |      11.27429      |
====================================================================================================================================================
[       OK ] ModelDiffTest.test_workload_wmt
======================================================================
FAIL: test_workload_fastmri (__main__.ModelDiffTest)
ModelDiffTest.test_workload_fastmri
test_workload_fastmri(workload='fastmri')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/chandramouli_sastry/.local/lib/python3.8/site-packages/absl/testing/parameterized.py", line 318, in bound_param_test
    return test_method(self, **testcase_params)
  File "/home/chandramouli_sastry/algorithmic-efficiency/tests/test_traindiffs.py", line 117, in test_workload
    self.assertTrue(  # grad_norms
AssertionError: False is not true

----------------------------------------------------------------------
Ran 8 tests in 849.021s

FAILED (failures=1)
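The fastmri failure is raised by the grad-norm assertion at line 117 of `tests/test_traindiffs.py`. In the logged values, the JAX grad norms (293615.97 at step 0, then roughly 1400 to 15000) are about 10^3 to 10^5 times larger than the PyTorch ones (about 1.1). By contrast, criteo1tb agrees to within float32 rounding. The sketch below quantifies that gap from the numbers in the tables. It is not the actual test code: the `GRAD_NORMS` table, the `max_relative_diff` and `report` helpers, and `HYPOTHETICAL_RTOL` are illustrative assumptions, and the real test's tolerance and comparison logic may differ.

```python
"""Sketch: quantify JAX-vs-PyTorch grad-norm divergence from the traindiff logs.

HYPOTHETICAL_RTOL is an illustrative choice, not the tolerance used by
tests/test_traindiffs.py.
"""
import numpy as np

# Per-step grad norms (jax, torch) copied from the criteo1tb and fastmri tables.
GRAD_NORMS = {
    "criteo1tb": (
        [1.3018, 1.29663, 1.29149, 1.28631, 1.28695,
         1.2821, 1.27727, 1.26699, 1.2631999, 1.25321],
        [1.3018, 1.29663, 1.29149, 1.28631, 1.28695,
         1.2821, 1.27727, 1.26699, 1.2632, 1.25321],
    ),
    "fastmri": (
        [293615.97, 15497.995, 5004.3184, 3284.724, 3348.0242,
         3152.087, 3772.5212, 6063.978, 4311.5225, 1400.8125],
        [1.09796, 1.09571, 1.09368, 1.09162, 1.08919,
         1.08695, 1.08475, 1.08248, 1.08031, 1.07817],
    ),
}

HYPOTHETICAL_RTOL = 1e-3  # assumption: not the test's real tolerance


def max_relative_diff(jax_vals, torch_vals):
    """Largest |jax - torch| / |torch| over all logged steps."""
    jax_arr = np.asarray(jax_vals, dtype=np.float64)
    torch_arr = np.asarray(torch_vals, dtype=np.float64)
    return float(np.max(np.abs(jax_arr - torch_arr) / np.abs(torch_arr)))


def report(rtol=HYPOTHETICAL_RTOL):
    """Print and return (max relative diff, all steps close?) per workload."""
    results = {}
    for workload, (jax_vals, torch_vals) in GRAD_NORMS.items():
        rel = max_relative_diff(jax_vals, torch_vals)
        # np.allclose checks |jax - torch| <= atol + rtol * |torch| elementwise.
        close = bool(np.allclose(jax_vals, torch_vals, rtol=rtol, atol=0.0))
        results[workload] = (rel, close)
        print(f"{workload:>10}: max rel diff = {rel:.3g}, "
              f"within rtol={rtol}: {close}")
    return results


if __name__ == "__main__":
    report()
```

With this hypothetical tolerance, criteo1tb is reported as close and fastmri is not. The same helpers can be pointed at the Eval or Train Loss columns to check whether the fastmri mismatch is confined to the gradient computation or also shows up in the loss.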
github-actions[bot] commented 8 months ago

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅