facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch

Usage with Torch's Autocast and GradScaler #8

Closed fhahlbohm closed 6 months ago

fhahlbohm commented 7 months ago

I am trying to make AdamWScheduleFree work with an optimization pipeline that uses https://pytorch.org/docs/stable/amp.html

More specifically, forward passes use torch.cuda.amp.autocast and gradients are scaled using torch.cuda.amp.GradScaler.

Here is some pseudocode for a training iteration:


def train_itr(self, training_sample):
    self.model.train()
    self.optimizer.train()  # self.optimizer is AdamWScheduleFree
    self.optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        outputs = self.model(training_sample)
        loss = self.loss(outputs, training_sample)
    self.grad_scaler.scale(loss).backward()   # backward on the scaled loss
    self.grad_scaler.step(self.optimizer)     # unscales grads, skips step on inf/NaN
    self.grad_scaler.update()
    # self.scheduler.step()  # no LR scheduler step, since the optimizer is schedule-free

Sadly, model outputs outside of training seem to contain NaN values. I saw that the README states additional steps might be necessary for my use case.

Is there an established way of doing this?

adefazio commented 7 months ago

This is something that's unclear to me; I don't understand the internals of the GradScaler implementation. I think you can set cache_enabled=False? Please let me know if that works for you.
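
For concreteness, a minimal sketch of where that flag would go in the loop you posted (cache_enabled is an argument of torch.cuda.amp.autocast; this is an untested suggestion, the rest of the loop stays the same):

# Untested sketch: disable autocast's weight cache so casted fp16 copies of the
# parameters are re-cast each time instead of being reused from the cache.
with torch.cuda.amp.autocast(enabled=True, cache_enabled=False):
    outputs = self.model(training_sample)
    loss = self.loss(outputs, training_sample)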

bhack commented 7 months ago

I am also interested in this.

fhahlbohm commented 7 months ago

I have not had time to test the suggested solution yet. I will post an update as soon as I find the time!

bhack commented 7 months ago

Have you tried with torch.nn.utils.clip_grad_norm_?
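
If you try that, note that with GradScaler the gradients should be unscaled before clipping. A sketch following the PyTorch AMP examples, reusing the names from the snippet above (the max norm of 1.0 is just a placeholder):

self.grad_scaler.scale(loss).backward()
self.grad_scaler.unscale_(self.optimizer)                     # unscale grads in place before clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # 1.0 is a placeholder max norm
self.grad_scaler.step(self.optimizer)                         # skips the step if grads contain inf/NaN
self.grad_scaler.update()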

yxchng commented 7 months ago

Does the current optimizer only work for fp32 training, or does it also work with AMP?

adefazio commented 7 months ago

So far it seems to be working correctly with GradScaler and autocast in our experiments using the nanoGPT codebase.
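
For anyone hitting NaNs at eval time, a rough sketch of the mode-switching pattern (simplified illustration using the method names from the snippet above, not our exact nanoGPT code): evaluate with the optimizer switched to eval mode so the averaged parameters are used, then switch back before training resumes.

# Simplified illustration: switch both the model and the schedule-free
# optimizer before evaluating, and back again afterwards.
self.model.eval()
self.optimizer.eval()   # use the averaged (evaluation) parameter sequence
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    outputs = self.model(validation_sample)  # validation_sample is a placeholder name

self.model.train()
self.optimizer.train()  # back to the gradient-evaluation parameters before training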

zhulinchng commented 4 months ago

@fhahlbohm were you able to find a solution for this? I also got NaN losses when training on multiple GPUs.