The same error occurs inside the Docker PyTorch Release 23.04 container: https://gist.github.com/arktoswb/d5835a666e7fcf9bfa3d7ff59173299c
Apparently, TransformerEngine is supported with `model_type = 'mcore'`. So, to continue training from the GPT-345M checkpoint:

```bash
python3 tools/checkpoint/convert.py --model-type GPT --loader megatron --saver megatron \
    --use-mcore-models
```

I will close this issue, but I suggest editing the example scripts and README to make this clearer.
Hello, thanks for suggesting convert.py, but I still have some problems. Could you please help me take a look at the issue I encountered when using it? Here is my command:

```bash
python3 tools/checkpoint/convert.py --model-type GPT --loader megatron --saver megatron \
    --load-dir models/megatron_lm_345m_v0.0 --save-dir models/convert/gpt2 \
    --megatron-path /home/zyz/code/Megatron-LM-core_v0.6.0
```

which reports:

```
File "/home/zyz/code/Megatron-LM-core_v0.6.0/tools/checkpoint/loader_megatron.py", line 70, in _load_checkpoint
    margs, checkpoint_args = load_args_from_checkpoint(margs, exit_on_missing_checkpoint=True)
TypeError: cannot unpack non-iterable Namespace object
```

Could you let me see all the code you're using here? Thanks again!
Yeah, there are multiple problems with that:

- `--loader megatron`: you are loading a Megatron model.
- `--saver megatron`: you are saving a Megatron model; for a Megatron Core model, use `--saver mcore` (see the corrected invocation sketched below).
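Put together, a corrected invocation would look roughly like this (a sketch only, reusing the paths from your command above; adjust them to your setup):

```bash
python3 tools/checkpoint/convert.py \
    --model-type GPT \
    --loader megatron \
    --saver mcore \
    --load-dir models/megatron_lm_345m_v0.0 \
    --save-dir models/convert/gpt2 \
    --megatron-path /home/zyz/code/Megatron-LM-core_v0.6.0
```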
As for loading, you can edit the code and hardcode a few model parameters to make it work: `load_args_from_checkpoint` returns just `args` under some conditions instead of returning `args, checkpoint_args`. You can bypass it by returning `args, args`.
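For illustration, the bypass amounts to something like the following (a sketch of the function named in the traceback; the early-return path is assumed here, not quoted from the Megatron source):

```python
def load_args_from_checkpoint(args, exit_on_missing_checkpoint=False):
    """Sketch: populate `args` with the arguments stored in a checkpoint."""
    # ... existing loading logic ...
    # Some paths return just `args`, which breaks callers that unpack two values:
    #     margs, checkpoint_args = load_args_from_checkpoint(margs, ...)
    # Returning the namespace twice keeps that unpack working:
    return args, args
```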
Once that is fixed, you will hit more errors, because the loader does not know all the model parameters it needs. Add them:
```python
# Edited section of tools/checkpoint/loader_megatron.py.
# The defaults added below correspond to the GPT-345M configuration.
check_for_arg('num_layers', 24)
check_for_arg('hidden_size', 1024)
check_for_arg('num_attention_heads', 16)
check_for_arg('max_position_embeddings', 1024)
check_for_arg('seq_length', 1024)
check_for_arg('tokenizer_type', 'GPT2BPETokenizer')

# Validate margs.
margs = validate_args(margs)

margs.use_mcore_models = False
margs.transformer_impl = args.loader_transformer_impl

check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('position_embedding_type')
check_for_arg('iteration', 666)  # arbitrary placeholder; the real iteration is unknown
check_for_arg('padded_vocab_size', 50304)
check_for_arg('bert_binary_head')
check_for_arg('disable_bias_linear', False)
check_for_arg('params_dtype')
check_for_arg('swiglu', False)
```
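For context, `check_for_arg` behaves roughly as follows: it falls back to the supplied default when the checkpoint args lack a value, and fails when there is none (a hypothetical reconstruction from the call sites above, not the upstream implementation):

```python
def check_for_arg(arg_name, default=None):
    # `margs` is the loader's args namespace from the enclosing scope.
    if getattr(margs, arg_name, None) is None:
        if default is not None:
            setattr(margs, arg_name, default)
        else:
            raise ValueError(f"Checkpoint does not specify '{arg_name}'.")
```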
But honestly, you will have an easier time loading the llama2 7B model. NeMo is an even better experience: it is in better shape and also supports llama3.
Thank you very much for your response, which has been very helpful.
Hi @arktoswb, does `convert.py` support converting a pretrained checkpoint with PP=1, TP=1 to PP>1 and/or TP>1? I want to finetune Mamba 8B from a pretrained checkpoint, but one GPU cannot meet the memory requirements.
Yes, I believe it does.
I stopped working with Megatron months ago, so I am not the best person to ask this question.
Thanks for replying! I did not find a flag for specifying PP and TP in `convert.py`. Do you have any clues on this, or do you know anyone who might?
From https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#evaluation-and-tasks: the flags `--target-tensor-parallel-size` and `--target-pipeline-parallel-size`.
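For example, resplitting a checkpoint to TP=2 and PP=2 would look roughly like this (a sketch only; the directories are placeholders, and per the update below this path does not cover Mamba):

```bash
python3 tools/checkpoint/convert.py \
    --model-type GPT \
    --loader megatron \
    --saver mcore \
    --load-dir <pretrained-checkpoint-dir> \
    --save-dir <resplit-checkpoint-dir> \
    --target-tensor-parallel-size 2 \
    --target-pipeline-parallel-size 2
```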
Thanks! It took some effort to solve the import error, but in the end I encountered another error, reported for every layer:

```
RuntimeError: Error(s) in loading state_dict for ParallelTransformer:
    Missing key(s) in state_dict: "layers.0.input_norm.weight", "layers.0.self_attention.query_key_value.weight", "layers.0.self_attention.dense.wei ...
```

I think the problem is specifically with how to enable the distributed model for Mamba. I opened another GitHub issue on this. But thanks for directing me to `convert.py`!

Update: `convert.py` does not support Mamba at the moment, but `hybrid_conversion.py` does.
**Describe the bug**
While running `examples/pretrain_gpt.sh` from the GPT-345M checkpoint, I encounter the error shown in the stack trace linked below.

**To Reproduce**
Run `examples/pretrain_gpt.sh` with the `--attention-softmax-in-fp32` arg added (it does not work otherwise). I also tried a llama2 checkpoint and got a similar error. However, the script runs successfully with `--transformer-impl local` from the provided GPT-345M checkpoint; that implementation is deprecated, though, and per my understanding will not work with llama models.

**Expected behavior**
`examples/pretrain_gpt.sh` should run fine from the GPT-345M checkpoint on the latest release without any modifications.

**Stack trace/logs**
https://gist.github.com/arktoswb/7830a87d514fd53cdad17882128d5122

**Environment:**
`stable` and `release_v1.1`

(related: https://github.com/NVIDIA/Megatron-LM/issues/577)