AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

converted mlperf gpt3 ckpt starts with a worse loss #887

Open gramesh-amd opened 1 week ago

gramesh-amd commented 1 week ago

Hello, we converted the paxml checkpoint and resumed training with the following config:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "tfds"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
tokenize_eval_data: False
eval_data_column: "ids"
add_bos: False
add_eos: False
eval_split: "validation_tokenized_5662seqs"
eval_interval: 10  # number of train steps between evals
target_eval_loss: 2.69  # early stop once the target eval_loss is reached

enable_checkpointing: True
save_interval_steps: 5

# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 10
model_name: "gpt3-175b" # this model config is unchanged
attention: "cudnn_flash_te"

gradient_accumulation_steps: 1

dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 5
max_target_length: 2048

remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False

dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint

skip_first_n_steps_for_profiler: 3

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                       # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
                       # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
                       # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
                      ['activation_heads', ['tensor','sequence']],
                      ['activation_kv_heads', ['tensor','sequence']],
                      ['activation_length', 'sequence'],
                      ['activation_embed', 'tensor'],
                      ['activation_mlp', 'tensor'],
                      ['activation_kv', 'tensor'],
                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                      ['activation_kv_head_dim', 'tensor'],
                      ['activation_vocab', ['tensor', 'sequence']],
                      ['activation_vocab', 'tensor'],
                      ['activation_vocab', 'sequence'],
                      ['activation_stage','stage'],
                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
                      ['vocab', ['tensor', 'autoregressive']],
                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
                      ['embed', ['fsdp', 'sequence']],
                      ['norm', 'fsdp'],
                      ['heads', ['tensor', 'autoregressive']],
                      ['layers', 'stage'],
                      ['kv', []],
                      ['kv_heads', ['tensor', 'autoregressive']],
                      ['kv_head_dim', []],
                      ['cache_batch', []],
                      ['cache_heads', ['autoregressive', 'tensor']],
                      ['cache_kv', []],
                      ['cache_sequence', []],
                    ]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

The tokenizer and the data splits (3.0.4, 3.0.5) were downloaded from the mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "c4_mlperf"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
eval_split: "validation_tokenized_5662seqs"
Training is launched with:

python maxtext/MaxText/train.py /dockerx/maxtext/MaxText/configs/gpt3_175b_gpu.yml base_output_directory=/ckpts/paxml/gpt3-conversion run_name=gpt3-conversion steps=4010 scan_layers=true

^ scan_layers is set to true, in line with how we converted the ckpt

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4002, seconds: 11.886, TFLOP/s/device: 185.471, Tokens/s/device: 172.310, total_weights: 65504, loss: 7.739, perplexity: 2297.215
completed step: 4003, seconds: 11.885, TFLOP/s/device: 185.479, Tokens/s/device: 172.318, total_weights: 65504, loss: 7.597, perplexity: 1992.680
completed step: 4004, seconds: 11.931, TFLOP/s/device: 184.759, Tokens/s/device: 171.649, total_weights: 65504, loss: 7.680, perplexity: 2165.097
completed step: 4005, seconds: 11.913, TFLOP/s/device: 185.043, Tokens/s/device: 171.912, total_weights: 65504, loss: 7.663, perplexity: 2128.778
completed step: 4006, seconds: 11.945, TFLOP/s/device: 184.546, Tokens/s/device: 171.451, total_weights: 65504, loss: 7.582, perplexity: 1963.248
completed step: 4007, seconds: 11.913, TFLOP/s/device: 185.048, Tokens/s/device: 171.918, total_weights: 65504, loss: 7.648, perplexity: 2096.574
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629

^ training starts with a very high loss; we expected something closer to 2.77

We have verified from the logs that training loads the right checkpoint, the correct data splits, and the tokenizer.
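For reference, the perplexity in these logs is just exp(loss), so a quick check in plain Python (values taken from the log above) shows the size of the gap:

import math

# perplexity in the log above matches exp(loss)
print(math.exp(7.644))  # ~2088, matching the logged perplexity at step 4000
print(math.exp(2.77))   # ~16, roughly the perplexity we expected at resume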

gramesh-amd commented 1 week ago

@ZhiyuLi-goog thanks again for your help with other issues. Do you see any problems with the config or know why the loss is much higher?

ZhiyuLi-goog commented 1 week ago

I have never tried this on GPU. To narrow down the root cause, could you try with normal attention?

attention: "dot_product" 
gramesh-amd commented 1 week ago

with attention: "dot_product":

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 39.677, TFLOP/s/device: 277.795, Tokens/s/device: 258.083, total_weights: 327520, loss: 7.638, perplexity: 2076.376
completed step: 4002, seconds: 39.883, TFLOP/s/device: 276.359, Tokens/s/device: 256.749, total_weights: 327520, loss: 7.646, perplexity: 2092.290

I get a similar loss as before.

ZhiyuLi-goog commented 1 week ago

Oh, could you try something like

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

instead of changing base.yml? You can find the exact model setup in gpt3-175b.yml, and there is some additional configuration for gpt3-175b:

# these flags might be relevant to output results
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
opt_type: "adam_pax"

I think logits_via_embedding: True should be the most important one.
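For context, here is a minimal illustrative sketch (not MaxText's actual implementation; the module name and sizes are made up) of what logits_via_embedding: True means: the output logits reuse the input embedding table, via flax's Embed.attend, instead of a separate output projection.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyTiedLM(nn.Module):
    """Toy decoder head with logits tied to the token embedding table."""
    vocab_size: int = 32
    embed_dim: int = 8

    @nn.compact
    def __call__(self, tokens):
        embed = nn.Embed(self.vocab_size, self.embed_dim, name="token_embed")
        x = embed(tokens)        # [batch, len, embed_dim]
        # ... decoder layers would run here ...
        return embed.attend(x)   # reuse the same table to get [batch, len, vocab]

model = TinyTiedLM()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 4), jnp.int32))
logits = model.apply(params, jnp.array([[1, 2, 3, 4]]))
print(logits.shape)  # (1, 4, 32)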

gramesh-amd commented 1 week ago

I tested these out. First running

python3 MaxText/train.py MaxText/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b

and then also adding the other relevant flags you posted, one by one; all of them start with the bad loss (7.6x). So it's not flash attention, the tokenizer (the validation set is pre-tokenized and the eval loss is also bad), or the config args (I tried the flags you suggested).

It's probably something to do with the model weights.

ZhiyuLi-goog commented 1 week ago

I can take a look at the full logs if you have them. The final effective config should be in that log.

gramesh-amd commented 1 week ago

maxtext_gpt3_logs.txt

Thanks. Here are the logs

ZhiyuLi-goog commented 1 week ago

I checked the log. All updated parameters matched, and I didn't find anything suspicious.

gramesh-amd commented 1 week ago

Thanks for checking. Yeah, it's strange that it starts with a bad loss. I also tested the tokenizer, and it seems fine.

ZhiyuLi-goog commented 1 week ago

The only thing I found that looks weird is:

+ Config param weight_dtype: float32
- Config param weight_dtype: bfloat16

Could you try weight_dtype: float32 instead of bfloat16? Activations are computed in bfloat16, while all parameters and optimizer state should be kept in float32 for better convergence.

However, I would not expect that to cause such a big gap.
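For what it's worth, here is a minimal sketch of that convention in plain flax (illustrative only, not MaxText code): parameters are stored in float32 while the layer computes in bfloat16.

import jax
import jax.numpy as jnp
import flax.linen as nn

# dtype controls the computation/activation precision,
# param_dtype controls how the weights are stored (and updated).
layer = nn.Dense(features=16, dtype=jnp.bfloat16, param_dtype=jnp.float32)

params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 8), jnp.bfloat16))
print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # kernel/bias are float32
print(layer.apply(params, jnp.ones((2, 8), jnp.bfloat16)).dtype)  # bfloat16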

gramesh-amd commented 1 week ago

Tried weight_dtype: float32 as well. Same problem.

I'm wondering if we could send you our converted ckpt for you to load and verify whether it's a ckpt problem?

ZhiyuLi-goog commented 1 week ago

I can give it a try on the TPU side.

By the way, would it be useful to you to print the mean of each param state after conversion?
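For example, a rough sketch of that kind of check (hypothetical: the path is a placeholder, and a real 175B checkpoint would need a sharded or lazy restore rather than a single-host load like this):

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

ckpt_path = "/ckpts/paxml/gpt3-conversion/converted"  # placeholder path
restored = ocp.PyTreeCheckpointer().restore(ckpt_path)

# Print the mean of every parameter leaf so the two conversions can be diffed.
leaves, _ = jax.tree_util.tree_flatten_with_path(restored)
for key_path, leaf in leaves:
    if hasattr(leaf, "dtype"):  # skip non-array metadata leaves
        print(jax.tree_util.keystr(key_path), jnp.asarray(leaf).mean())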

gramesh-amd commented 4 days ago

I'm not sure if it will be useful. We also loaded the pax ckpt directly in paxml, and there it starts at the right loss. So at this point, we suspect something is going wrong during conversion.

ZhiyuLi-goog commented 4 days ago

It would be easiest if you have a converted ckpt you can share; I can directly compare your converted ckpt against ours. If you have the output log from the conversion script, I can take a look at that as well.

We didn't try this on GPU, so I guess something might behave differently there.

gramesh-amd commented 4 days ago

Great, we will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that I could push them to, or do you recommend some other way?

ZhiyuLi-goog commented 4 days ago

we will share the converted ckpt and the conversion logs. Do you have a gcloud bucket that I could push them to, or do you recommend some other way?

It would be great if you could share an open gcloud bucket with us. By the way, which conversion script are you using: the one in the mlperf 4.0 submission or the one in the maxtext main branch?

gramesh-amd commented 4 days ago

OK, let me do that. We tried both versions, and with both we are getting the same problem.

ZhiyuLi-goog commented 4 days ago

We tried both versions, and with both we are getting the same problem.

Gotcha, thank you for the info!