When I was preparing to train the qwen2 model using the gritlm project, I encountered this error: AssertionError: Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with proper initializations. #43

YuvalCheung commented 6 days ago

When I ran the script with bash, I encountered the following error:

[default2]:[rank2]: AssertionError: Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with proper initializations.
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]:   File "/usr/local/miniconda3/envs/gritlm/lib/python3.10/", line 196, in _run_module_as_main
[default0]:[rank0]:     return _run_code(code, main_globals, None,
[default0]:[rank0]:   File "/usr/local/miniconda3/envs/gritlm/lib/python3.10/", line 86, in _run_code
[default0]:[rank0]:     exec(code, run_globals)
[default0]:[rank0]:   File "/mnt/home/zhangyu/gritlm-main/gritlm/training/", line 440, in <module>
[default0]:[rank0]:     main()
[default0]:[rank0]:   File "/mnt/home/zhangyu/gritlm-main/gritlm/training/", line 422, in main
[default0]:[rank0]:     trainer.train()
[default0]:[rank0]:   File "/usr/local/miniconda3/envs/gritlm/lib/python3.10/site-packages/transformers/", line 1932, in train
[default0]:[rank0]:     return inner_training_loop(
[default0]:[rank0]:   File "/mnt/home/zhangyu/gritlm-main/gritlm/training/", line 691, in _inner_training_loop
[default0]:[rank0]:     loss_emb = gc(inputs["query"], inputs["passage"], no_sync_except_last=no_sync_except_last)
[default0]:[rank0]:   File "/usr/local/miniconda3/envs/gritlm/lib/python3.10/site-packages/grad_cache/", line 70, in __call__
[default0]:[rank0]:     return self.cache_step(*args, **kwargs)
[default0]:[rank0]:   File "/usr/local/miniconda3/envs/gritlm/lib/python3.10/site-packages/grad_cache/", line 262, in cache_step
[default0]:[rank0]:     assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
[default0]:[rank0]: AssertionError: Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with proper initializations.
[default0]:wandb: - 0.000 MB of 0.000 MB uploaded
[default0]:wandb: You can sync this run to the cloud by running:
[default0]:wandb: wandb sync /mnt/home/zhangyu/gritlm-main/gritlm/wandb/offline-run-20240630_002800-8j46ny1y
[default0]:wandb: Find logs at: ./wandb/offline-run-20240630_002800-8j46ny1y/logs
[default0]:wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See for more information.

In order to train the model with qwen2, I modified the file:

#SBATCH --job-name=gritlm
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
#SBATCH --hint=nomultithread         # we get physical cores not logical
#SBATCH --partition=a3
#SBATCH --gres=gpu:8                 # number of gpus
#SBATCH --time 999:00:00             # maximum execution time (HH:MM:SS)
#SBATCH --output=/data/niklas/jobs/%x-%j.out           # output file name
#SBATCH --exclusive

### Set environment ###
cd /mnt/home/zhangyu/gritlm-main/gritlm
export WANDB_PROJECT="gritlm"
# export WANDB_PROJECT="gritlm"
export WANDB_MODE="offline"
# so processes know who to talk to

# OUT_DIR="/mnt/home/zhangyu/output/test_llama3_8b_6"
# MODEL="/mnt/tenant-home_speed/AIM/model/llama3-8b-Instruct"

LAUNCHER="accelerate launch \
    --config_file /mnt/home/zhangyu/gritlm-main/scripts/configs/config_8gpusfsdp_m7_qwen.yml \
    --num_machines $NNODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip "$MASTER_ADDR" \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --rdzv_conf rdzv_backend=c10d \
    --max_restarts 0 \
    --tee 3 \
    "

export CMD=" \
    -m \
    --output_dir $OUT_DIR \
    --model_name_or_path $MODEL \
    --train_data $DATA_DIR \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --max_steps 1253 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --per_device_generative_bs 1 \
    --dataloader_drop_last \
    --normalized \
    --temperature 0.02 \
    --train_group_size 2 \
    --negatives_cross_device \
    --query_max_len 256 \
    --passage_max_len 2048 \
    --mode unified \
    --logging_steps 1 \
    --bf16 \
    --pooling_method mean \
    --use_unique_indices \
    --loss_gen_type mixed \
    --attn bbcc \
    --attn_implementation sdpa \
    --no_gen_gas \
    --gradient_checkpointing \
    --report_to "tensorboard" \
    --save_strategy "epoch" \
    --num_train_epochs 1 \
    --save_steps 5000 \
    "

SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

# clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1

bash -c "$LAUNCHER $CMD"

And in the config_8gpusfsdp_m7_qwen.yml file, I set fsdp_transformer_layer_cls_to_wrap to Qwen2DecoderLayer.

At the same time, I modified the file under /lib/python3.10/site-packages/transformers/models/qwen2/ based on the differences between the reference files, and overwrote the original according to the requirements of the gritlm project.

Here is my file:

Finally, when I ran the script with bash, the error occurred right at the start. Here is the log:


Now I'm wondering whether the issue lies in my .sh script settings or in the modifications I made to the qwen2 model. Thank you very much for your assistance.

Muennighoff commented 6 days ago

You need to cd into gritlm/training/GradCache & run pip install -e . to get this change.

It seems like you installed GradCache from its own repo, but the version in this repo is the one that needs to be installed.
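A quick way to confirm which GradCache copy Python actually imports is to inspect the module's origin. This is a minimal sketch with hypothetical helper names (`module_origin`, `uses_repo_copy`, not part of gritlm); it assumes the repo copy lives under a `training/GradCache` directory, so an origin under `site-packages/grad_cache/` (as in the traceback above) would suggest the upstream package is still being picked up.

```python
import importlib.util
from pathlib import PurePath
from typing import Optional


def module_origin(module_name: str) -> Optional[str]:
    """Return the file Python would import `module_name` from, or None if it is not installed."""
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        return None
    return spec.origin


def uses_repo_copy(origin: Optional[str], repo_subdir: str = "training/GradCache") -> bool:
    """Heuristic: True if `origin` lies inside the repo's GradCache checkout.

    Assumes `pip install -e .` run inside gritlm/training/GradCache makes Python
    import from that directory instead of from site-packages.
    """
    if origin is None:
        return False
    parts = PurePath(origin).parts
    wanted = PurePath(repo_subdir).parts
    # Look for the consecutive path components, e.g. ("training", "GradCache").
    return any(parts[i:i + len(wanted)] == wanted for i in range(len(parts) - len(wanted) + 1))


if __name__ == "__main__":
    origin = module_origin("grad_cache")
    print("grad_cache resolved to:", origin)
    print("using the gritlm repo copy:", uses_repo_copy(origin))
```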

Muennighoff commented 6 days ago

Adjusted the README a bit, lmk if this is better:

YuvalCheung commented 5 days ago

Thank you for your help. After installing GradCache the correct way, the previous error no longer occurs, but I have run into another one. Have you encountered similar errors before?

Muennighoff commented 5 days ago

I think this is the timeout issue listed at the top here:

YuvalCheung commented 5 days ago

The situation described in the Known issues section matches exactly what I'm seeing. I'm currently using a single node with 8 GPUs (1*8). Could you advise how to make sure the saving process doesn't get terminated?

Also, I have a question about the file structure of the model I generated after training. The file structure is as follows:

(gritlm) root@ctmt240625013845lar-558799cd4d-sdndz:/mnt/home/zhangyu/output/test_mistral_2# tree
├── checkpoint-37
│   ├── optimizer.bin
│   ├── pytorch_model.bin
│   ├── pytorch_model_fsdp.bin
│   ├── rng_state_0.pth
│   ├── rng_state_1.pth
│   ├── rng_state_2.pth
│   ├── rng_state_3.pth
│   ├── rng_state_4.pth
│   ├── rng_state_5.pth
│   ├── rng_state_6.pth
│   ├── rng_state_7.pth
│   ├──
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   ├── trainer_state.json
│   └── training_args.bin
├── dataset_num_samples.json
└── runs
    └── Jul01_01-10-06_ctmt240625013845lar-558799cd4d-sdndz
        └── events.out.tfevents.1719767875.ctmt240625013845lar-558799cd4d-sdndz.771494.0

3 directories, 20 files

It looks quite different from the file structure at the link:

# tree
├── config.json
├── dataset_num_samples.json
├── generation_config.json
├── model-00001-of-00003.safetensors
├── model-00002-of-00003.safetensors
├── model-00003-of-00003.safetensors
├── model.safetensors.index.json
├── pytorch_model-00001-of-00003.bin
├── pytorch_model-00002-of-00003.bin
├── pytorch_model-00003-of-00003.bin
├── pytorch_model.bin.index.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
├── tokenizer_config.json
└── training_args.bin

0 directories, 18 files

Thank you for your help.

Muennighoff commented 4 days ago

1) Sorry, I don't know how to solve it beyond what is mentioned in the Known issues section.
2) We shard the ckpt via main/scripts/ for easier usage. I've added this to the README.
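For context on why the two listings differ: a sharded checkpoint splits the weights across several `model-0000K-of-0000N.safetensors` files plus a `model.safetensors.index.json` whose `weight_map` records which file holds each tensor. The sketch below only illustrates that layout with a greedy size-based split; `plan_shards` is a hypothetical helper, not the gritlm sharding script (which is not shown here). In practice, transformers' `save_pretrained(..., max_shard_size=...)` writes such files for you.

```python
from typing import Dict, List


def plan_shards(tensor_sizes: Dict[str, int], max_shard_bytes: int) -> dict:
    """Greedily group tensors (in insertion order) into shards of at most `max_shard_bytes`.

    Returns an index dict shaped like model.safetensors.index.json:
    {"metadata": {"total_size": ...}, "weight_map": {tensor_name: shard_file}}.
    A single tensor larger than the limit gets a shard of its own.
    """
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name, size in tensor_sizes.items():
        # Start a new shard if adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_bytes:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)

    total = len(shards)
    weight_map = {}
    for idx, names in enumerate(shards, start=1):
        shard_file = f"model-{idx:05d}-of-{total:05d}.safetensors"
        for name in names:
            weight_map[name] = shard_file
    return {"metadata": {"total_size": sum(tensor_sizes.values())}, "weight_map": weight_map}


if __name__ == "__main__":
    sizes = {"model.embed_tokens.weight": 4, "model.layers.0.mlp.up_proj.weight": 4, "lm_head.weight": 4}
    print(plan_shards(sizes, max_shard_bytes=8))
```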

YuvalCheung commented 4 days ago

Thank you for your help.

When I don't set no_emb_gas and no_gen_gas to True, the nccl timeout issue disappears. Should these two options have no effect on the model's capabilities?

Also, after training the model, I obtained these files:

(gritlm) root@ctmt240625013845lar-558799cd4d-sdndz:/mnt/home/zhangyu/output/test_mistral_13# tree .
├── checkpoint-1
│   ├── model.safetensors
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   ├── trainer_state.json
│   └── training_args.bin
├── checkpoint-2
│   ├── model.safetensors
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   ├── trainer_state.json
│   └── training_args.bin
├── config.json
├── dataset_num_samples.json
├── full_state_dict
│   ├── model.safetensors
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   └── training_args.bin
├── model.safetensors
├── runs
│   └── Jul02_00-46-47_ctmt240625013845lar-558799cd4d-sdndz
│       └── events.out.tfevents.1719852887.ctmt240625013845lar-558799cd4d-sdndz.1341295.0

Then I copied the config.json file to the checkpoint-1 directory.
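Copying config.json by hand works; with many checkpoints, a small script can do it for every `checkpoint-*` directory. A minimal sketch, assuming the layout shown above (config.json at the top of the output directory, checkpoints in `checkpoint-<step>` subdirectories); `copy_config_to_checkpoints` is a hypothetical helper name.

```python
import shutil
import sys
from pathlib import Path
from typing import List


def copy_config_to_checkpoints(output_dir: str) -> List[Path]:
    """Copy <output_dir>/config.json into every checkpoint-* subdirectory that lacks one.

    Returns the list of config files that were written.
    """
    root = Path(output_dir)
    config = root / "config.json"
    if not config.is_file():
        raise FileNotFoundError(f"no config.json in {root}")
    written = []
    for ckpt in sorted(root.glob("checkpoint-*")):
        target = ckpt / "config.json"
        # Only fill in missing configs; never overwrite an existing one.
        if ckpt.is_dir() and not target.exists():
            shutil.copy2(config, target)
            written.append(target)
    return written


if __name__ == "__main__" and len(sys.argv) > 1:
    for path in copy_config_to_checkpoints(sys.argv[1]):
        print("wrote", path)
```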

When deploying the model using vllm, an error occurs:

vllm deployment script:

CUDA_VISIBLE_DEVICES=0 python3 -m vllm.entrypoints.openai.api_server \
    --served-model-name ZTEAIM-Gritm-Base \
    --model "/mnt/home/zhangyu/output/test_mistral_13/checkpoint-1" \
    --port 6000 \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization 0.95 \
    --dtype bfloat16 \
    --max-model-len 4096 \
    --api-key 10344626

I received the following error message:

(vllm) root@ctmt240625013845lar-558799cd4d-sdndz:/mnt/home/zhangyu/model_deployment# bash 
INFO 07-02 10:03:12] args: Namespace(host=None, port=6000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='10344626', served_model_name='Gritm-Base', chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, root_path=None, middleware=[], model='/mnt/home/zhangyu/output/test_mistral_13/checkpoint-1', tokenizer=None, revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='bfloat16', kv_cache_dtype='auto', max_model_len=4096, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, block_size=16, seed=0, swap_space=4, gpu_memory_utilization=0.95, max_num_batched_tokens=None, max_num_seqs=256, max_paddings=256, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, disable_custom_all_reduce=False, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', max_cpu_loras=None, device='cuda', engine_use_ray=False, disable_log_requests=False, max_log_len=None)
INFO 07-02 10:03:13] Initializing an LLM engine with config: model='/mnt/home/zhangyu/output/test_mistral_13/checkpoint-1', tokenizer='/mnt/home/zhangyu/output/test_mistral_13/checkpoint-1', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
Traceback (most recent call last):
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/openai/", line 217, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 625, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 321, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 366, in _init_engine
    return engine_class(*args, **kwargs)
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 120, in __init__
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 164, in _init_workers
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/", line 1012, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/", line 102, in load_model
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/worker/", line 84, in load_model
    self.model = get_model(self.model_config, self.device_config,
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/", line 86, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/usr/local/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/", line 374, in load_weights
    param = params_dict[name]
KeyError: 'model.lm_head.weight'

If I deploy GritLM-7B instead, the deployment succeeds:

(vllm) root@ctmt240625013845lar-558799cd4d-sdndz:/mnt/home/zhangyu/model_deployment# bash 
INFO 07-02 10:04:54] args: Namespace(host=None, port=6000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='10344626', served_model_name='Gritm-Base', chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, root_path=None, middleware=[], model='/mnt/tenant-home_speed/AIM/model/GritLM-7B', tokenizer=None, revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='bfloat16', kv_cache_dtype='auto', max_model_len=4096, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, block_size=16, seed=0, swap_space=4, gpu_memory_utilization=0.95, max_num_batched_tokens=None, max_num_seqs=256, max_paddings=256, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, disable_custom_all_reduce=False, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', max_cpu_loras=None, device='cuda', engine_use_ray=False, disable_log_requests=False, max_log_len=None)
INFO 07-02 10:04:54] Initializing an LLM engine with config: model='/mnt/tenant-home_speed/AIM/model/GritLM-7B', tokenizer='/mnt/tenant-home_speed/AIM/model/GritLM-7B', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 07-02 10:05:01] # GPU blocks: 31287, # CPU blocks: 2048
INFO 07-02 10:05:06] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 07-02 10:05:06] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 07-02 10:05:09] Graph capturing finished in 4 secs.
INFO 07-02 10:05:10] Using default chat template:
INFO 07-02 10:05:10] {{ bos_token }}{% for message in messages %}
INFO 07-02 10:05:10] {% if message['role'] == 'user' %}
INFO 07-02 10:05:10] {{ '<|user|>
INFO 07-02 10:05:10] ' + message['content'] }}
INFO 07-02 10:05:10] {% elif message['role'] == 'assistant' %}
INFO 07-02 10:05:10] {{ '<|assistant|>
INFO 07-02 10:05:10] '  + message['content'] + eos_token }}
INFO 07-02 10:05:10] {% endif %}
INFO 07-02 10:05:10] {% if loop.last and add_generation_prompt %}
INFO 07-02 10:05:10] {{ '<|assistant|>' }}
INFO 07-02 10:05:10] {% endif %}
INFO 07-02 10:05:10] {% endfor %}
INFO:     Started server process [1554267]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on (Press CTRL+C to quit)
INFO 07-02 10:05:20] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%
INFO 07-02 10:05:30] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%

Is there something missing from the model I trained? Here is my training script:

#SBATCH --job-name=gritlm
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per dist per node!
#SBATCH --hint=nomultithread         # we get physical cores not logical
#SBATCH --partition=a3
#SBATCH --gres=gpu:8                 # number of gpus
#SBATCH --time 999:00:00             # maximum execution time (HH:MM:SS)
#SBATCH --output=/data/niklas/jobs/%x-%j.out           # output file name
#SBATCH --exclusive

### Set environment ###
cd /mnt/home/zhangyu/gritlm-main/gritlm
export WANDB_PROJECT="gritlm"
export WANDB_MODE="offline"
# so processes know who to talk to

# OUT_DIR="/mnt/home/zhangyu/output/test_llama3_8b_7"
# OUT_DIR="/mnt/home/zhangyu/output/test_qwen2_10"

# MODEL="/mnt/tenant-home_speed/AIM/model/qwen2_7B_chat"
# MODEL="/mnt/tenant-home_speed/AIM/model/llama3-8b-Instruct"


# YMLPATH="/mnt/home/zhangyu/gritlm-main/scripts/configs/config_8gpusfsdp_m7_qwen.yml"
# YMLPATH="/mnt/home/zhangyu/gritlm-main/scripts/configs/config_8gpusfsdp_m7_llama.yml"
# YMLPATH="/mnt/home/zhangyu/gritlm-main/scripts/configs/config_8gpusddp_m7.yml"

LAUNCHER="accelerate launch \
    --config_file $YMLPATH \
    --num_machines $NNODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip "$MASTER_ADDR" \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --role $SLURMD_NODENAME: \
    --rdzv_conf rdzv_backend=c10d \
    --max_restarts 0 \
    --tee 3 \
    "

export CMD=" \
    -m \
    --output_dir $OUT_DIR \
    --model_name_or_path $MODEL \
    --train_data $DATA_DIR \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataloader_drop_last \
    --normalized \
    --temperature 0.02 \
    --train_group_size 2 \
    --negatives_cross_device \
    --query_max_len 256 \
    --passage_max_len 2048 \
    --mode unified \
    --logging_steps 1 \
    --bf16 \
    --pooling_method mean \
    --use_unique_indices \
    --loss_gen_factor 0.003 \
    --loss_gen_type token \
    --attn bbcc \
    --attn_implementation sdpa \
    --gradient_checkpointing \
    --report_to "tensorboard" \
    --save_strategy "epoch" \
    --save_steps 1 \
    --save_only_model \
    --save_safetensors \
    --max_steps 1500 \
    --ddp_backend gloo \
    --num_train_epochs 1 \
    "

SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

    # --no_gen_gas \
    # --no_emb_gas \
#     --no_gen_gas \
#       --split_emb \
    # --split_emb_full \
# clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1

#     --max_steps 1253 \

    # --save_strategy "epoch" \

bash -c "$LAUNCHER $CMD"

Thank you again for answering my questions.

Muennighoff commented 4 days ago

When I don't set no_emb_gas and no_gen_gas to True, the nccl timeout issue disappears. Should these two options have no effect on the model's capabilities?

It should not affect capabilities. The difference may come from the trainer, since a different trainer is used in that case; see the explanation here:

Is there something missing from the model I trained? Here is my training script:

I don't notice a major problem. I would check the safetensors file and compare its keys with the keys of the GritLM-7B model files. Just load them each in memory and check that they have the exact same keys.
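The key comparison suggested above could look like the sketch below. It assumes each model directory contains one or more `*.safetensors` files (GritLM-7B is sharded, while the trained checkpoint has a single `model.safetensors`) and reads only the tensor names via `safe_open(...).keys()`, so no weights are loaded; `load_keys` and `diff_keys` are hypothetical helper names.

```python
import sys
from pathlib import Path
from typing import Set, Tuple


def load_keys(model_dir: str) -> Set[str]:
    """Collect tensor names from every *.safetensors file in `model_dir`."""
    from safetensors import safe_open  # imported lazily so diff_keys works without safetensors

    keys: Set[str] = set()
    for st_file in sorted(Path(model_dir).glob("*.safetensors")):
        with safe_open(str(st_file), framework="pt") as f:
            keys.update(f.keys())
    return keys


def diff_keys(trained: Set[str], reference: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Return (names only in the trained model, names only in the reference model)."""
    return trained - reference, reference - trained


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python <this_script>.py <trained_dir> <reference_dir>
    only_trained, only_ref = diff_keys(load_keys(sys.argv[1]), load_keys(sys.argv[2]))
    print("only in trained:", sorted(only_trained)[:10])
    print("only in reference:", sorted(only_ref)[:10])
```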

Also, FYI: GritLM-7B was trained from the base Mistral 7B, not the instruct version you are using, but I don't think it matters much.

YuvalCheung commented 3 days ago

I think I know why the model loading failed. I used the following code to inspect the generated .safetensors file:

from safetensors import safe_open

st_file = "/mnt/home/zhangyu/output/test_mistral_16/checkpoint-1/model.safetensors"
with safe_open(st_file, framework="pt") as f:
    for name in f.keys():
        # Print each tensor name with its shape
        print(name, f.get_tensor(name).shape)

I found that the tensor names in my model look like this:


But when I inspected the GritLM-7B model and the qwen2 model the same way, the output looked like this:


Obviously, the model I trained has an extra "model." prefix in the names.

I'm not sure which parameter causes this; it could be an environment issue or a library version problem. The only workaround I can think of right now is to strip the prefix from the names after training, with code like this. It's a crude method, but it should work.

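import os
import sys

from safetensors import safe_open
from safetensors.torch import save_file

PREFIX = "model."

# Usage: python <this_script>.py <input model.safetensors> [<output model.safetensors>]
st_file = sys.argv[1]
new_file = sys.argv[2] if len(sys.argv) > 2 else "/mnt/home/zhangyu/output/test/model.safetensors"
os.makedirs(os.path.dirname(new_file) or ".", exist_ok=True)

save_dict = {}
with safe_open(st_file, framework="pt") as f:
    for name in f.keys():
        # Drop the leading "model." only when it is actually there
        new_name = name[len(PREFIX):] if name.startswith(PREFIX) else name
        save_dict[new_name] = f.get_tensor(name)

# Record the framework in the header; some transformers versions check this field when loading
save_file(save_dict, new_file, metadata={"format": "pt"})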

I don't know if others will encounter this problem, but I will investigate the cause of this issue later. For now, I'll focus on getting the project to run. Thank you for patiently answering my questions.

Muennighoff commented 3 days ago

Oh yes, I think that is expected, and there is a script for that here:

I've added it to the README, sorry!

YuvalCheung commented 3 days ago

Thank you so much for your answer!

I have one more question: Is it possible for me to move the operation that saves the config.json in to before the training starts?

The modified code looks like this:

    # Save tokenizer & config for easy usage afterwards
    if trainer.is_world_process_zero(): 
        config.to_json_file(training_args.output_dir + "/config.json")

    # Training
    logger.info("Starting training")
    trainer.train()

    # The below does not save if state dict type is `SHARDED_STATE_DICT`
    trainer.save_model()

    # To be safe do another FS save
    if (trainer.is_fsdp_enabled) and (trainer.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"):
        fsd_path = os.path.join(training_args.output_dir, "full_state_dict")
        os.makedirs(fsd_path, exist_ok=True)

I ask because I might pick an intermediate checkpoint during training for deployment, and deployment requires the config.json file. Will this change have any side effects?

Muennighoff commented 3 days ago

I think that should be fine.