poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Feature Request] Gradient Accumulation #225

Closed lkhphuc closed 2 years ago

lkhphuc commented 2 years ago

Is your feature request related to a problem? Please describe.

Adding support for gradient accumulation. I think this is a simple enough modification that it should be built in as a flag in the fit method.

Describe the solution you'd like

model.fit(inputs, batch_size=16, batch_grad_accum=2, epochs=10)

Describe alternatives you've considered

It doesn't make much sense to implement this via a Callback, and it should be common enough that users shouldn't have to drop down to the low-level API.

lkhphuc commented 2 years ago

This can be done quite easily with Optax's MultiSteps wrapper. Very neat, love it.

optimizer = optax.MultiSteps(optax.sgd(3e-4), every_k_schedule=k).gradient_transformation()
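
For anyone landing here, a minimal self-contained sketch of how the wrapper behaves in a plain JAX training loop (the toy params, loss, and `k = 2` are illustrative, not part of elegy):

import jax
import jax.numpy as jnp
import optax

k = 2  # accumulate gradients over 2 mini-batches before applying an update
optimizer = optax.MultiSteps(
    optax.sgd(3e-4), every_k_schedule=k
).gradient_transformation()

params = {"w": jnp.ones((3,))}
opt_state = optimizer.init(params)

def loss_fn(params, x):
    # Toy quadratic loss, just to have something to differentiate.
    return jnp.sum((params["w"] * x) ** 2)

for step in range(4):
    x = jnp.full((3,), float(step + 1))
    grads = jax.grad(loss_fn)(params, x)
    # MultiSteps accumulates grads internally; `updates` are zeros except
    # on every k-th call, where the averaged gradient is applied via SGD.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

Since the result is an ordinary optax GradientTransformation, it should drop into any API that accepts an optax optimizer, which is presumably why no elegy-side flag is needed.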