google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0

Perform gradient clipping on global batch when using gradient accumulation #9

Open ashors1 opened 1 year ago

ashors1 commented 1 year ago

Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that this refactor maintains support for enable_skip_step_on_gradient_anomalies. When using ShardedStaticAccumulator with x subbatches, it requires x+1 gradient norm calculations per global batch (one per subbatch to determine whether the step should be skipped, plus one when applying gradient clipping in the base optimizer update) and one gradient clip per global batch.
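For illustration, here is a minimal JAX sketch of the scheme described above (not the actual Praxis/Paxml implementation): the gradient norm is computed once per subbatch purely to decide whether the step should be skipped, and clipping is applied a single time to the accumulated global-batch gradient. All names here (`loss_fn`, `params`, `subbatches`, `clip_threshold`) are hypothetical.

```python
# Hypothetical sketch, not the Praxis/Paxml API.
import jax
import jax.numpy as jnp
import optax


def global_norm(grads):
  # L2 norm over all leaves of the gradient pytree.
  leaves = jax.tree_util.tree_leaves(grads)
  return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


def accumulate_and_clip(loss_fn, params, subbatches, clip_threshold):
  """Accumulates subbatch gradients, clipping only the global-batch gradient."""
  accum = jax.tree_util.tree_map(jnp.zeros_like, params)
  skip_step = False

  for sub in subbatches:
    grads = jax.grad(loss_fn)(params, sub)
    # One grad-norm computation per subbatch, used only to decide whether the
    # step should be skipped (mirroring enable_skip_step_on_gradient_anomalies).
    sub_norm = global_norm(grads)
    skip_step = skip_step or not bool(jnp.isfinite(sub_norm))
    accum = jax.tree_util.tree_map(jnp.add, accum, grads)

  # Average the accumulated gradients to obtain the global-batch gradient.
  accum = jax.tree_util.tree_map(lambda g: g / len(subbatches), accum)

  # One additional grad-norm computation and a single clip, applied to the
  # full-batch gradient rather than to each subbatch.
  clipper = optax.clip_by_global_norm(clip_threshold)
  clipped, _ = clipper.update(accum, clipper.init(accum))
  return clipped, skip_step
```

This gives x+1 norm computations and exactly one clip per global batch, matching the accounting above.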

This PR should be taken together with the corresponding Praxis PR.

zhangqiaorjc commented 1 year ago

@ashors1 sorry for the late review, could you rebase to head? I want to import it and run some internal CI, thanks!

zhangqiaorjc commented 1 year ago

There are quite a few redundant whitespace changes. Could you run a Python linter to remove those?

nluehr commented 1 year ago

@zhangqiaorjc is there a reason this has been approved but not merged yet?