gramesh-amd opened 1 week ago
@ZhiyuLi-goog thanks again for your help with other issues. Do you see any problems with the config or know why the loss is much higher?
I have never tried this on GPU. To narrow down the root cause, could you try normal attention?
attention: "dot_product"
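For example, it can be passed as a command-line override rather than edited into base.yml (the run name below is just a placeholder):
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" attention="dot_product"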
with attention: "dot_product":
completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 39.677, TFLOP/s/device: 277.795, Tokens/s/device: 258.083, total_weights: 327520, loss: 7.638, perplexity: 2076.376
completed step: 4002, seconds: 39.883, TFLOP/s/device: 276.359, Tokens/s/device: 256.749, total_weights: 327520, loss: 7.646, perplexity: 2092.290
I get a similar loss as before.
Oh, could you try something like
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b
instead of changing base.yml? You can find the exact model setup in gpt3-175b.yml, and there is some additional setup for gpt3-175b:
# these flags might be relevant to output results
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
opt_type: "adam_pax"
I think logits_via_embedding: True should be the most important one.
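If it is easier than editing base.yml, these can also be appended as command-line overrides on top of the model config; a sketch, if you want to pin them explicitly (run name is a placeholder):
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b logits_via_embedding=True opt_type="adam_pax"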
I tested these out, first running
python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b
and then adding the other relevant flags you posted one by one, and all of them start with the bad loss (7.6x). So it's not flash attention, the tokenizer (validation is pretokenized and the evaluated loss is also bad), or the config args (since I tried the flags you suggested).
It's probably something to do with the model weights.
I can take a look at the full logs if you have them. The final effective config should be in that log.
Thanks. Here are the logs
Checked the log. All updated parameters matched and I didn't find anything suspicious.
Thanks for checking. Yeah, it's strange that it's starting with a bad loss. I also tried testing the tokenizer and it seems fine.
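(For illustration, a minimal round-trip sanity check of that sort, assuming a SentencePiece tokenizer model, might look like the sketch below; the model path is a placeholder.)
import sentencepiece as spm

# Hypothetical path to the SentencePiece model used for tokenization.
sp = spm.SentencePieceProcessor(model_file="/path/to/tokenizer.model")

text = "the quick brown fox jumps over the lazy dog"
ids = sp.encode(text, out_type=int)
# For plain lowercase ASCII text, the round trip should reproduce the input exactly.
assert sp.decode(ids) == text, "encode/decode round trip changed the text"
print("vocab size:", sp.get_piece_size(), "first ids:", ids[:8])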
The only thing I found that looks weird is
+ Config param weight_dtype: float32
- Config param weight_dtype: bfloat16
Could you try using weight_dtype as float32 instead of bfloat16? Activations are computed in bfloat16, while all parameters and optimizer state should be kept in float32 for better convergence.
However, I do not expect such a big gap.
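In base.yml terms, the intended combination would be roughly the following (assuming the usual dtype / weight_dtype pair):
dtype: "bfloat16"        # activations / compute
weight_dtype: "float32"  # parameters and optimizer state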
Tried weight_dtype as float32 as well. Same problem.
I'm wondering if we can send you our converted ckpt for you to load and verify whether it's a ckpt problem?
I can give it a try on the TPU side.
By the way, would it be useful to you to print the mean of each param state after conversion?
I'm not sure if it will be useful. We also loaded the pax ckpt directly in paxml and it starts at the right loss, so at this point we suspect something is going wrong during conversion.
It would be easiest if you have a converted ckpt; I can directly compare your converted ckpt against ours. If you have the output log from the conversion script, I can take a look at that as well.
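Something along these lines could dump per-parameter statistics from a converted checkpoint so two conversions can be diffed side by side (a rough sketch using Orbax; the checkpoint path and tree layout are assumptions and may need adjusting for your conversion output):
import jax
import numpy as np
import orbax.checkpoint as ocp

# Hypothetical path to the converted MaxText checkpoint items directory.
ckpt = ocp.PyTreeCheckpointer().restore("/path/to/converted_ckpt/items")

# Print shape/mean/std for every array leaf in the restored tree.
leaves, _ = jax.tree_util.tree_flatten_with_path(ckpt)
for path, leaf in leaves:
    if not hasattr(leaf, "shape"):
        continue  # skip non-array leaves such as strings or metadata
    arr = np.asarray(leaf)
    print(jax.tree_util.keystr(path), arr.shape, float(arr.mean()), float(arr.std()))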
We didn't try that on GPU; I guess something might be different there.
Great, we will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that I could push them to, or do you recommend some other way?
It would be great if you can share it with us via an open gcloud bucket. By the way, which conversion script are you using? Is it the one in the mlperf 4.0 submission or the one in the maxtext main branch?
Ok, let me do that. We tried both versions, and with both we are getting the same problem.
Gotcha, thank you for the info!
Hello, we converted the paxml checkpoint and resumed training with the following config:
The tokenizer and data splits (3.0.4, 3.0.5) were downloaded from the mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:
^ scan_layers is set to true, in line with how we converted the ckpt
^ starts with a very high loss; we expected something closer to 2.77
We have verified from the logs that training loads the right checkpoint, the correct data splits, and the tokenizer.