ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License

How to use Grokfast with FP16 mixed precision training? #10

Open peterjc123 opened 2 months ago

peterjc123 commented 2 months ago

Hi, I'm trying out Grokfast in an LLM scenario. Mixed precision training is a commonly used technique to reduce GPU memory usage and speed up training. The following code is an example of FP16 training:

import torch
from torch import autocast
from torch.cuda.amp import GradScaler

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of the optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of the optimizer's assigned params are unscaled, clip as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # The optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for the next iteration.
        scaler.update()

My question is: where should I put grads = gradfilter_ema(model, grads)? I tried placing it between scaler.scale(loss).backward() and scaler.unscale_(optimizer), but that doesn't work; the loss scale just explodes.
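For concreteness, this is roughly what I tried, with the filter call placed between the scaled backward pass and unscaling (same training skeleton as above, with grads initialized to None):

grads = None  # EMA state carried across iterations by gradfilter_ema

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Filter applied here, i.e. before unscaling: p.grad still holds the
        # *scaled* gradients, so the EMA state presumably absorbs the loss scale
        # as well. This is the variant where the loss scale explodes for me.
        grads = gradfilter_ema(model, grads)

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()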

damian0815 commented 2 months ago

Similar issue here: when I put grads = gradfilter_ema(model, grads) after the call to scaler.unscale_(optimizer), the scale goes to 0 and I get NaNs for the step loss.
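For reference, the relevant part of my inner loop looks roughly like this (same skeleton as the snippet above, with grads initialized to None outside the loop):

        scaler.scale(loss).backward()

        # Unscale first so the filter sees the true (unscaled) gradients ...
        scaler.unscale_(optimizer)

        # ... then apply the Grokfast EMA filter, which adds the low-pass
        # filtered gradient onto p.grad in place.
        grads = gradfilter_ema(model, grads)

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # With this ordering the scale collapses to 0 and the step loss is NaN.
        scaler.step(optimizer)
        scaler.update()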

ironjr commented 1 month ago

Thank you for the valuable report! This is most likely due to the increased gradient norm caused by the added low-pass-filtered gradient.

The code here is basically a proof-of-concept demonstration of accelerated grokking in the previously known scenarios. For larger models, I suspect more sophisticated control of the gradient update step size is needed, especially with the mixed precision training you mentioned. I will revise the code to improve compatibility with larger models in the next version.
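Until then, one rough direction you could experiment with (untested on my side, and not necessarily what the revised code will do) is to keep the post-filter gradient norm in check, for example by measuring the global gradient norm right before the filter and shrinking the combined gradient back to it afterwards:

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # filter and clip on unscaled gradients

        # Measure the global gradient norm before filtering.
        params = [p for p in model.parameters() if p.grad is not None]
        pre_norm = torch.norm(torch.stack([p.grad.detach().norm() for p in params]))

        # Grokfast EMA filter: adds the low-pass filtered gradient onto p.grad.
        grads = gradfilter_ema(model, grads)

        # Shrink the combined gradient back down if the filter increased its norm,
        # then apply the usual clipping.
        torch.nn.utils.clip_grad_norm_(params, float(pre_norm))
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        scaler.step(optimizer)  # still skipped automatically on infs/NaNs
        scaler.update()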