jzhang38 / TinyLlama

The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.

Gradient problem when the number of devices is 1 #88

Closed: SivilTaram closed this issue 10 months ago

SivilTaram commented 10 months ago

First, thanks for all the effort you have put into the TinyLlama project - it's awesome!

Recently I ran into a puzzling problem: when there is only one GPU, the training gradients vanish. Is this expected? Is training the model on a single GPU currently unsupported?

Thanks!

jzhang38 commented 10 months ago

Can you share more details including the error message?

radarFudan commented 10 months ago

I ran the script on a single node using the following command. The training loss does not decrease, and the evaluated gradient norm is zero.

lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=1 \
    --num-nodes=1 \
    pretrain/tinyllama.py --devices 1 --train_data_dir data/slim_star  --val_data_dir data/slim_star

Gradient norm:

if not is_accumulating:
    # clip_grad_norm_ returns the total gradient norm computed before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
    # Print the gradient norm
    print("Gradient norm: ", total_norm)

    fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
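
Side note: if I read the Fabric API correctly, fabric.clip_gradients itself returns the total gradient norm (computed before clipping) when max_norm is given, so the same measurement can be made without clipping twice:

if not is_accumulating:
    # clip_gradients with max_norm returns the pre-clipping total norm
    total_norm = fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
    print("Gradient norm: ", total_norm)
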
jzhang38 commented 10 months ago

lit-gpt also has this issue: https://github.com/Lightning-AI/lit-gpt/issues/689. Probably an issue with Lightning Fabric, or some toggle that needs to be on. (Or should I call it an issue? Because Fabric is designed for multi-GPU.)

Update: @ChaosCodes managed to reproduce the error on a single GPU.

with fabric.no_backward_sync(model, enabled=is_accumulating):
    logits = model(input_ids)

The logits produced by this line are always 0.
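
A quick way to check whether the weights were ever materialized (a hypothetical sanity check, run right after the model is constructed; model is the module from the pretrain script):

import torch

with torch.no_grad():
    for name, param in model.named_parameters():
        # Weights that were never materialized after empty_init tend to show
        # up as all-zero tensors (freshly allocated CUDA memory is often
        # zeroed), which would explain the all-zero logits.
        if torch.all(param == 0):
            print("all-zero parameter:", name)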

ChaosCodes commented 10 months ago

Hi, if you want to train with 1 GPU, you perhaps need to set empty_init=False here. But I have no idea yet why it leads to the gradient problem.

This should fix the problem for now https://github.com/jzhang38/TinyLlama/commit/782f1824dd6dae05adb4dcf1d784259006a9b1f4
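
For reference, empty_init is the flag passed to Fabric's init_module context manager when the model is built; a minimal sketch of the single-GPU pattern (GPT and config are placeholders for the actual model setup in pretrain/tinyllama.py):

import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=1)
fabric.launch()

# empty_init=True defers weight materialization to the strategy, which is
# meant for sharded multi-GPU setups; on a single device the weights may
# never be properly initialized, so force real initialization instead.
with fabric.init_module(empty_init=False):
    model = GPT(config)  # placeholder for the pretrain script's model setup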