mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Add schedule-free adamw submission in JAX #809

Open priyakasimbeg opened 3 weeks ago

priyakasimbeg commented 3 weeks ago

Description

Currently we have been unable to reproduce the schedule-free AdamW results in JAX. There seem to be differences between the optax implementation of schedule-free AdamW and the PyTorch submission.

adefazio commented 3 weeks ago

I can help debug any issues here. Do you have any code you can share? If there are issues with the optax jax implementation I want to get it fixed asap.

adefazio commented 2 weeks ago

There are many small differences between the behavior of the schedule-free JAX wrapper and the original AlgoPerf submission. Some differences I'm aware of:

So overall I expect the JAX wrapper version to give equally good results on all problems (maybe slightly slower on FastMRI), so if there is a difference it would be from some sort of bug.

priyakasimbeg commented 5 days ago

Hi Aaron! Thanks for weighing in on this. I seem to have missed your messages on this thread.

We have a slightly modified version based on the optax code here: https://github.com/priyakasimbeg/algorithmic-efficiency/blob/compare_schedule_free/tests/test_algorithms/schedule_free_adamw/jax/schedule_free_optax.py. This code adds the r parameter, and we tested it with r=0.75 on our Google-internal codebase.
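For context, here is roughly how r enters the schedule-free averaging, modeled on the reference PyTorch schedule-free implementation rather than the exact optax code linked above; the function name `averaging_weight` and the coupling with `weight_lr_power` are assumptions for illustration.

```python
import jax.numpy as jnp

def averaging_weight(step, scheduled_lr, weight_sum, r=0.75, weight_lr_power=2.0):
  """Interpolation coefficient c for the schedule-free averaged iterate x.

  Sketch modeled on the reference PyTorch implementation: the per-step weight
  is (step + 1)**r * lr**weight_lr_power, and c is that weight divided by the
  running sum of all weights so far. With a constant learning rate, r = 0
  recovers a plain running average of the z iterates.
  """
  weight = ((step + 1) ** r) * (scheduled_lr ** weight_lr_power)
  new_weight_sum = weight_sum + weight
  # new_weight_sum is positive after the first step; guard the division anyway.
  c = jnp.where(new_weight_sum > 0, weight / new_weight_sum, 0.0)
  return c, new_weight_sum
```

The averaged iterate is then updated as x ← (1 - c) * x + c * z, so a larger r puts more weight on recent z iterates.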

I'm working on a test to compare the PyTorch and JAX implementations side by side on the algoperf GitHub code, but the test is still in progress. I can perhaps run a full training run on some of the workloads. In the meantime, feel free to weigh in again if you spot any other differences.
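This is not the actual test, but the kind of side-by-side check described above can be sketched as follows: step both implementations on identical synthetic gradients and report where the parameter trajectories first diverge. The `(params, grad, step) -> new_params` callable interface and the helper name are assumptions for illustration; in practice each callable would wrap the PyTorch or JAX update with numpy conversion at the boundaries.

```python
import numpy as np

def compare_optimizer_traces(update_a, update_b, init_params, num_steps=10,
                             rtol=1e-5, atol=1e-6, seed=0):
  """Step two optimizer implementations on the same synthetic gradients.

  Returns (step, max_abs_error) for the first step where the two parameter
  trajectories disagree beyond tolerance, or (None, 0.0) if they stay in
  agreement for all num_steps.
  """
  rng = np.random.default_rng(seed)
  params_a = np.array(init_params, dtype=np.float64)
  params_b = np.array(init_params, dtype=np.float64)
  for step in range(num_steps):
    grad = rng.normal(size=params_a.shape)  # identical gradient for both sides
    params_a = np.asarray(update_a(params_a, grad, step), dtype=np.float64)
    params_b = np.asarray(update_b(params_b, grad, step), dtype=np.float64)
    if not np.allclose(params_a, params_b, rtol=rtol, atol=atol):
      return step, float(np.max(np.abs(params_a - params_b)))
  return None, 0.0
```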

adefazio commented 5 days ago

Ok, I'll take a look and see if I spot any differences.

adefazio commented 5 days ago

It looks like the z buffer may be initialized with zeros: https://github.com/priyakasimbeg/algorithmic-efficiency/blob/5556015054e3dda681e2a25e05a2f217d933453d/tests/test_algorithms/schedule_free_adamw/jax/submission.py#L58C51-L59C1. It needs to be initialized the same as the main parameter buffer. I think this line is a copy-paste error from the JAX version of NAdamW and other methods, where all optimizer state is normally initialized at zero.

Suggestion: you might want to set z on the first call to the main optimizer update; that's what we do in the PyTorch version.
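To make the fix concrete, here is a minimal sketch of both options, assuming the parameters are a JAX pytree; the function names are hypothetical and not taken from the submission code.

```python
import jax
import jax.numpy as jnp

def init_z_buffer(params):
  """Option 1: initialize z as a copy of the parameter pytree.

  Using jnp.zeros_like here (as NAdamW-style templates do for their momentum
  buffers) would start the averaged iterate at zero instead of at the initial
  parameters.
  """
  return jax.tree_util.tree_map(jnp.array, params)

def lazily_init_z(z, params, step):
  """Option 2: overwrite z with the live parameters on the first update step,
  which is roughly what the PyTorch version is described as doing."""
  return jax.tree_util.tree_map(
      lambda z_i, p_i: jnp.where(step == 0, p_i, z_i), z, params)
```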

adefazio commented 2 days ago

@priyakasimbeg Let me know if that initialization issue was the problem.