`bmtrain.optim.OptimManager` can do this during the backward step. Can it meet your needs?
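For reference, a minimal sketch of the loss-scaling setup with `OptimManager`, following the usage pattern documented in the BMTrain README. `model` and `dataloader` are placeholders, and exact class names or signatures may differ across BMTrain versions:

```python
import bmtrain as bmt

bmt.init_distributed()

# model is assumed to be a bmt-wrapped module (placeholder)
optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=1e-4)

# OptimManager owns the dynamic loss scale; if an overflow (NaN/Inf gradient)
# is detected during backward, it skips the step and lowers the loss scale.
optim_manager = bmt.optim.OptimManager(loss_scale=2**20)
optim_manager.add_optimizer(optimizer)

for batch in dataloader:            # placeholder data loader
    optim_manager.zero_grad()
    loss = model(**batch)           # placeholder forward returning a scalar loss
    optim_manager.backward(loss)    # scales the loss and runs backward
    optim_manager.step()            # skipped internally if overflow was detected
```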
If NaN is found during the backward step, the loss scale will be adjusted to a smaller value. When NaN is found in the forward step, it probably implies bad data. Could we just skip the current step with no backward pass (and no loss scale change)? @Achazwl
Maybe you can use `bmt.sum_loss(loss)` to find out if any rank has NaN? `bmt.sum_loss(loss)` will be the same on each rank, therefore you can skip globally.
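A minimal sketch of that pattern, assuming `bmt.sum_loss` all-reduces the loss so every rank gets the same scalar; `model`, `dataloader`, and `optim_manager` are placeholders:

```python
import math
import bmtrain as bmt

for batch in dataloader:                        # placeholder data loader
    optim_manager.zero_grad()
    loss = model(**batch)                       # placeholder forward pass
    reduced_loss = bmt.sum_loss(loss).item()    # identical value on every rank

    if math.isnan(reduced_loss) or math.isinf(reduced_loss):
        # every rank sees the same reduced_loss, so all ranks skip together:
        # no backward, no step, no loss-scale change
        continue

    optim_manager.backward(loss)
    optim_manager.step()
```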
We used this as follows, but ran into one case: when a NaN batch is found (already checked; inf and NaN tensors do appear in the forward step), `reduced_loss` is NaN and the current global step is skipped. However, starting from that batch, rank 0 keeps producing NaN, and it seems `sum_loss` holds the NaN state from then on.
Could you give some advice on debugging this case? Thanks!
After skipping the current data, does the next batch of data still produce NaN only on rank 0? After continuously skipping the subsequent steps, do all batches produce NaN only on rank 0 and not on the other ranks?
After skipping the current data, the model output loss on each rank (the `loss` variable in the screenshot above) is a valid value, but `reduced_loss` on each rank is still NaN. I am not sure whether `bmt.sum_loss` still holds the NaN or whether `lm_loss.detach().clone().view(1)` has something wrong.
After continuously skipping the next steps, all batches are skipped because of the same behaviour.
Could you please give some suggestions about this case? Thanks!
`sum_loss` will not retain the previous NaN. Is the loss half type or float32 type now? Use a regular Python `print` to check whether any rank's loss has extreme values, for example near the value boundary of the dtype.
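One way to do that check with a plain `print`, assuming `bmt.rank()` for the rank id; `loss` is the per-rank loss from the forward pass:

```python
import torch
import bmtrain as bmt

local = loss.detach().float()   # cast to float32 so printing is not limited by fp16
# fp16 overflows above ~65504, so a loss near that boundary can become Inf/NaN
# once it is scaled or reduced in half precision.
print(f"rank {bmt.rank()}: loss={local.item():.6e}, "
      f"isnan={bool(torch.isnan(local))}, isinf={bool(torch.isinf(local))}",
      flush=True)
```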
The whole model and the loss are half type. By printing each layer's outputs (including intermediate activations), no inf or NaN is found, which tells us each rank outputs a valid loss. However, `sum_loss` (and `sum_loss.detach()`) is NaN.
When NaN happens during the forward step on one rank, how can we skip this global step on all ranks? Thanks.
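One possible workaround (not an official BMTrain recipe): reduce a clean 0/1 NaN indicator instead of the loss itself, so a NaN produced on one rank never enters the all-reduce. `model`, `batch`, and `optim_manager` are placeholders; only `bmt.sum_loss` is taken from the thread:

```python
import torch
import bmtrain as bmt

def should_skip_step(loss: torch.Tensor) -> bool:
    """Return True on every rank if any rank's loss is NaN/Inf."""
    # The flag is always finite, so the reduction itself cannot turn into NaN.
    bad = (~torch.isfinite(loss.detach())).float().view(1)
    bad_any = bmt.sum_loss(bad)       # same value on every rank
    # Works whether sum_loss sums or averages: any bad rank makes it > 0.
    return bad_any.item() > 0

# inside the training loop:
loss = model(**batch)
if should_skip_step(loss):
    optim_manager.zero_grad()          # drop this batch on all ranks together
else:
    optim_manager.backward(loss)
    optim_manager.step()
```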