Liuhong99 / Sophia

The official implementation of “Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training”

About `bs`, `hessian_update`, and extending to tasks beyond CLM #27

Closed KeremTurgutlu closed 1 year ago

KeremTurgutlu commented 1 year ago

Thanks a lot for this amazing work and for open-sourcing the code. I am always very excited about new methods that democratize AI, especially when they can speed up LLM training by this much :)

I have integrated the Sophia optimizer into my training script, which is based on the HF Accelerate DeepSpeed training script. It also includes some custom modifications, such as Triton-based flash attention and a fused loss function, for further improving MFU; I am able to achieve around ~50% MFU for large models and around ~40% for smaller models (e.g. 280M).

So far I have run some early tests comparing the throughput of Sophia vs. AdamW, on a single A100 80GB:

| Model Size | Optimizer | seq / sec |
|------------|-----------|-----------|
| 1B         | AdamW     | ~20.3     |
| 1B         | Sophia    | ~17.7     |
| 280M       | AdamW     | ~60       |
| 280M       | Sophia    | ~47       |

In summary, I am observing a 10-15% drop in throughput; is this expected? I believe the paper mentions roughly a 5% overhead, not this much. The main problem is that Sophia seems to increase GPU memory usage, so the max batch size I can fit drops from 52 to 47 for the 280M model. Do you have any recommendations about this?

I have a few follow-up questions:

1) For the `bs` parameter, this is my understanding (assuming multi-GPU/process training):

total_batch_size = num_accumulation_steps * batch_size_per_gpu
bs = total_batch_size * block_size # (aka sequence length)

Is this correct?
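
To make this concrete, here is a minimal sketch of what I mean, with made-up values from my own config (nothing here is from this repo); I am assuming the result is what gets passed to `optimizer.step(bs=...)`:

```python
# Sketch with made-up values from my own config (not from this repo).
num_accumulation_steps = 8       # gradient accumulation steps
batch_size_per_gpu = 47          # micro-batch size per process
block_size = 2048                # sequence length in tokens

total_batch_size = num_accumulation_steps * batch_size_per_gpu  # sequences per optimizer step
bs = total_batch_size * block_size                              # tokens per optimizer step

# ... which I then pass to the optimizer step:
# optimizer.step(bs=bs)
```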

2) About the hessian update: I was wondering whether we should handle padding when computing the loss by ignoring padding tokens? Does the implementation assume packed inputs? Currently `ignore_index` doesn't have any effect:

# Gauss-Newton-Bartlett(θ)
outputs = model(**batch)
logits = outputs.logits
y_sample = torch.distributions.Categorical(logits=logits).sample()
# ignore_index=-1 will never be used, categories are non-negative
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1),
                                          ignore_index=-1)
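
For example, would something like this be the right way to handle padding? Just a sketch of what I have in mind, assuming `batch["attention_mask"]` is 0 on padding tokens (the names come from my own pipeline, not this repo):

```python
# Gauss-Newton-Bartlett(θ) with padding masked out (sketch, not from this repo)
outputs = model(**batch)
logits = outputs.logits
y_sample = torch.distributions.Categorical(logits=logits).sample()
# Set sampled labels at padded positions to -1 so cross_entropy drops them
# via ignore_index (attention_mask is assumed to be 0 on padding tokens).
y_sample = y_sample.masked_fill(batch["attention_mask"] == 0, -1)
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)),
                                         y_sample.view(-1), ignore_index=-1)
```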

3) Does anything need to change in the hessian update loss computation if a different task, such as prefix LM, is used instead of causal LM? Prefix LM and similar tasks (e.g. T5) don't compute loss on tokens that have bidirectional attention. So if someone wants to use Sophia for such a task, should we skip the hessian update (Gauss-Newton-Bartlett) loss on those tokens by using masking and `ignore_index`? This is similar to question 2 above.

Liuhong99 commented 1 year ago

Thanks for the questions!

  1. Yeah, the `bs` is the number of sequences * sequence length.
  2. `ignore_index` in this code base does not take effect. As you said, it can be useful when there are padding tokens, e.g. in instruction tuning or MLM.
  3. Nothing special is needed for T5 or BERT. We just need the same loss for the Hessian estimates as for the gradient computation; see the sketch below.
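
For reference, the overall update pattern is roughly the example loop from the README; below is a condensed sketch, assuming an HF-style causal LM whose output exposes `.loss` and `.logits` and an already-built `model` / `data_loader` (hyper-parameters are just the README's illustrative values, and the masking line is only needed when there are padding or non-loss tokens):

```python
import torch
import torch.nn.functional as F
from sophia import SophiaG

# Illustrative hyper-parameters (from the README example), not tuned values.
optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99),
                    rho=0.01, weight_decay=1e-1)

bs = 480 * 2048  # number of sequences per optimizer step * sequence length (illustrative)
k = 10           # refresh the Hessian EMA every k steps

for iter_num, batch in enumerate(data_loader):
    # Regular step: usual cross-entropy loss on the true labels.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step(bs=bs)
    optimizer.zero_grad(set_to_none=True)

    if iter_num % k == k - 1:
        # Gauss-Newton-Bartlett Hessian estimate: same loss form as above,
        # but with labels sampled from the model's own output distribution.
        logits = model(**batch).logits
        y_sample = torch.distributions.Categorical(logits=logits).sample()
        # Optional: y_sample = y_sample.masked_fill(batch["attention_mask"] == 0, -1)
        loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                       y_sample.view(-1), ignore_index=-1)
        loss_sampled.backward()
        optimizer.update_hessian()
        optimizer.zero_grad(set_to_none=True)
```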