OpenBMB / BMTrain

Efficient Training (including pre-training and fine-tuning) for Big Models

how to skip one iter for all ranks #108

Closed. ftgreat closed this issue 1 year ago.

ftgreat commented 1 year ago

When a NaN occurs during the forward step on one rank, how can we skip that global step on all ranks? Thanks.

Achazwl commented 1 year ago

bmtrain.optim.OptimManager can do this during the backward step; does that meet your needs?
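For reference, the usual loss-scaled loop looks roughly like the sketch below (README-style API; optimizer, model, and dataloader are placeholders, and exact signatures may differ across BMTrain versions). When the scaled gradients overflow during backward, the manager skips the update and lowers the loss scale.

```python
import bmtrain as bmt

# Sketch of a loss-scaled training loop with OptimManager.
optim_manager = bmt.optim.OptimManager(loss_scale=2**20)
optim_manager.add_optimizer(optimizer)   # optimizer defined elsewhere

for batch in dataloader:
    optim_manager.zero_grad()
    loss = model(**batch)                # forward pass
    optim_manager.backward(loss)         # scaled backward
    optim_manager.step()                 # skips the update and shrinks the loss
                                         # scale if gradients contain NaN/inf
```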

ftgreat commented 1 year ago

bmtrain.optim.OptimManager can do this during the backward step; does that meet your needs?

If a NaN is found during the backward step, the loss scale is adjusted to a smaller value. But when a NaN appears in the forward step, it probably implies bad data. Could we just skip the current step with no backward (and no loss scale change)? @Achazwl

Achazwl commented 1 year ago

Maybe you can use bmt.sum_loss(loss) to find out whether any rank has a NaN? bmt.sum_loss(loss) will be the same on every rank, so you can skip globally.
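Something along these lines (a sketch only; model, dataloader, and optim_manager are placeholders): because the reduced loss is identical on every rank, every rank takes the same branch and the skip is global.

```python
import torch
import bmtrain as bmt

for batch in dataloader:                        # usual training loop
    loss = model(**batch)                       # per-rank forward loss
    reduced_loss = bmt.sum_loss(loss)           # identical value on every rank
    if not torch.isfinite(reduced_loss).all():  # some rank produced NaN/inf in forward
        continue                                # all ranks skip together: no backward,
                                                # no loss-scale change
    optim_manager.zero_grad()
    optim_manager.backward(loss)
    optim_manager.step()
```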

ftgreat commented 1 year ago

Maybe you can use bmt.sum_loss(loss) to find out whether any rank has a NaN? bmt.sum_loss(loss) will be the same on every rank, so you can skip globally.

We used this as follows, but ran into one case: when a NaN batch is found (already checked: inf and NaN tensors do appear in the forward step), reduced_loss equals NaN and the current global step is skipped. However, after this batch rank 0 seems to keep the NaN value, and sum_loss appears to stay in the NaN state from then on.

Could you give some advice on debugging this case? Thanks! (screenshot attached)

Achazwl commented 1 year ago

After skipping the current data, does the next batch still produce NaN only on rank 0? After continuously skipping subsequent steps, do all batches produce NaN only on rank 0 and not on the other ranks?

ftgreat commented 1 year ago

After skipping the current data, does the next batch still produce NaN only on rank 0? After continuously skipping subsequent steps, do all batches produce NaN only on rank 0 and not on the other ranks?

After skipping the current data, the model output loss (the loss variable in the screenshot above) on each rank is a valid value, but reduced_loss on each rank is still NaN. I'm not sure whether bmt.sum_loss still holds the NaN or whether lm_loss.detach().clone().view(1) has something wrong.

After continuously skipping subsequent steps, all batches are skipped because of the same behaviour.

ftgreat commented 1 year ago

Could you please give some suggestions about the case above? Thanks.

Achazwl commented 1 year ago

sum_loss will not retain the previous NaN. Is the loss half type or float32 type now? Use a regular Python print to check whether any rank's loss has extreme values, e.g. near the boundary of the value range.
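For example, something like this sketch (bmt.rank() gives the rank index; fp16 overflows above roughly 65504, so values near that boundary are suspect):

```python
import math
import bmtrain as bmt

# Debugging sketch: inspect each rank's raw loss before reduction.
local = loss.float().item()
print(f"rank {bmt.rank()}: loss = {local}, finite = {math.isfinite(local)}", flush=True)
```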

ftgreat commented 1 year ago

sum_loss will not retain the previous NaN. Is the loss half type or float32 type now? Use a regular Python print to check whether any rank's loss has extreme values, e.g. near the boundary of the value range.

The whole model and the loss are half type. By printing each layer's outputs (including intermediate activations), no inf or NaN is found, which means each rank outputs a valid loss. However, sum_loss (or sum_loss.detach()) is NaN.
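(For context, the per-layer check can be done with forward hooks along these lines; this is an illustrative sketch, not the exact code used.)

```python
import torch

# Illustrative sketch: flag any non-finite values in intermediate activations.
def make_hook(name):
    def hook(module, inputs, output):
        tensors = output if isinstance(output, (tuple, list)) else (output,)
        for t in tensors:
            if torch.is_tensor(t) and not torch.isfinite(t).all():
                print(f"non-finite output in {name}", flush=True)
    return hook

for name, module in model.named_modules():   # model is a placeholder
    module.register_forward_hook(make_hook(name))
```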