Open · ashors1 opened 1 year ago
@ashors1 Sorry for the late review, could you rebase to head? I want to import it and run some internal CI, thanks!
There are quite a few redundant whitespace characters. Could you run a Python linter to remove those?
@zhangqiaorjc is there a reason this has been approved but not merged yet?
Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using `ShardedStaticAccumulator`. Note that this refactor allows us to maintain support for `enable_skip_step_on_gradient_anomalies`. When using `ShardedStaticAccumulator` with `x` subbatches, it requires `x + 1` grad norm calculations per global batch (once per subbatch to determine whether the step should be skipped, and once when applying gradient clipping in the base optimizer update) and one grad clip per global batch.

This PR should be taken together with the corresponding Praxis PR.
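For illustration only, here is a minimal eager (non-jitted) sketch of the behaviour described above, not the actual Paxml/Praxis implementation. It assumes an optax base optimizer, and names such as `loss_fn`, `clip_norm`, and `accumulate_and_update` are hypothetical:

```python
import jax
import jax.numpy as jnp
import optax


def accumulate_and_update(params, opt_state, subbatches, optimizer, loss_fn,
                          clip_norm=1.0):
  """Accumulates grads over subbatches; clips once on the full-batch grad."""
  grad_fn = jax.grad(loss_fn)
  accum = jax.tree_util.tree_map(jnp.zeros_like, params)
  skip_step = False

  # One grad-norm calculation per subbatch, used only to check for gradient
  # anomalies (NaN/Inf) so the whole step can be skipped if needed.
  for subbatch in subbatches:
    grads = grad_fn(params, subbatch)
    skip_step = skip_step or not bool(jnp.isfinite(optax.global_norm(grads)))
    accum = jax.tree_util.tree_map(jnp.add, accum, grads)

  if skip_step:
    # Rough analogue of enable_skip_step_on_gradient_anomalies: drop the
    # whole update and keep the previous params and optimizer state.
    return params, opt_state

  # Average over subbatches, then do one more grad-norm calculation and a
  # single clip on the full-batch gradient: x + 1 norms and one clip per
  # global batch.
  accum = jax.tree_util.tree_map(lambda g: g / len(subbatches), accum)
  norm = optax.global_norm(accum)
  scale = jnp.minimum(1.0, clip_norm / (norm + 1e-6))
  clipped = jax.tree_util.tree_map(lambda g: g * scale, accum)

  updates, new_opt_state = optimizer.update(clipped, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state
```

The key point of the refactor is visible in the sketch: the per-subbatch norms are used only for the skip-step check, while clipping happens once, on the accumulated full-batch gradient, inside the base optimizer update.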