MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
The fastmri JAX comparison is failing. I spent several hours trying to figure out what's going on but wasn't able to make any progress. I also tried upgrading the JAX version, but that didn't help either. This test used to pass, as can be seen from this screenshot: https://github.com/mlcommons/algorithmic-efficiency/pull/314#discussion_r1106357658
Current results from the traindiff tests are as follows (note that the PyTorch logs for fastmri are the same as in the screenshot above):
I wasn't able to run the GitHub Actions runner because of some unrelated issues, but I think this should fix the major problems with the traindiffs test. If not, I'll ask @priyakasimbeg for help with further debugging.