microsoft / simulated-trial-and-error


Error Occurring at loss.backward() Despite Loss Being Calculable #14

Open aidarikako opened 2 months ago

aidarikako commented 2 months ago

Hello,

I am encountering an issue when running the following code snippet:

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py \
    --enable_fsdp \
    --model_name \
    --num_epochs 2 \
    --batch_size_training 16 \
    --micro_batch_size 1 \
    --val_batch_size 8 \
    --lr 2e-5 \
    --num_workers_dataloader 1 \
    --seed 42 \
    --data_path \
    --max_words_dataset 2048 \
    --checkpoint_folder \
    --save_with_hf \
    --warmup_ratio 0.03 \
    --save_epoch_interval 1 \
    --add_token_list ft_datasets/toolken_list_50.json

This results in the following error:

File "/mnt/home/xlh/code/simulated-trial-and-error-main/simulated-trial-and-error-main/llama-recipes/utils/train_utils.py", line 94, in train loss.backward() File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward torch.autograd.backward( File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/autograd/init.py", line 266, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1112, in unpack_hook frame.recompute_fn(args) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1401, in recompute_fn fn(args, kwargs) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 741, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, **kwargs) File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 671, in forward attn_output = torch.nn.functional.scaled_dot_product_attention( RuntimeError: The expanded size of the tensor (4096) must match the existing size (2048) at non-singleton dimension 3. Target sizes: [1, 32, 2048, 4096]. Tensor sizes: [1, 1, 2048, 2048]

The error occurs at loss.backward(), even though the loss value is computed successfully and can be printed. I would appreciate any insights or suggestions on possible causes of this error.

Thank you for your help!

Boshi-Wang commented 1 month ago

Sorry for the late reply. I ran the code on my side and did not observe this issue. It seems to be some kind of mismatch in the tensor shapes; have you tried tracing back from the offending tensor?
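For example, something along these lines could help narrow it down (a debugging sketch only, untested against your exact setup; `model` is assumed to be the LlamaForCausalLM built in llama_finetuning.py):

# Debugging sketch, not a fix: anomaly detection makes the backward error point at the
# forward op that produced the offending tensor, and the hooks log the hidden_states /
# attention_mask shapes seen by each attention layer.
import torch

torch.autograd.set_detect_anomaly(True)  # noticeably slower; enable only while debugging

def make_shape_logger(layer_idx):
    def log_shapes(module, args, kwargs, output):
        hidden_states = kwargs.get("hidden_states", args[0] if args else None)
        attention_mask = kwargs.get("attention_mask")
        print(
            f"layer {layer_idx}:",
            None if hidden_states is None else tuple(hidden_states.shape),
            None if attention_mask is None else tuple(attention_mask.shape),
        )
    return log_shapes

# `model` is assumed to be the loaded LlamaForCausalLM (assumption, not from the repo code)
for i, layer in enumerate(model.model.layers):
    layer.self_attn.register_forward_hook(make_shape_logger(i), with_kwargs=True)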

aidarikako commented 1 month ago

> Sorry for the late reply. I ran the code on my side and did not observe this issue. It seems to be some kind of mismatch in the tensor shapes; have you tried tracing back from the offending tensor?

After further investigation, I found that the issue actually originates a few lines before the one where the error is reported, in modeling_llama.py: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

Specifically, after the past_key_value.update call, the last dimension of key_states changes from 2048 to 4096. Closer inspection with print statements showed that this does not happen for every sample: initially the dimension stays at 2048, but starting from a certain sample key_states becomes 4096 after past_key_value.update, which then triggers the error above and terminates the run. The problem therefore lies in the past_key_value.update operation.
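The check I added locally looked roughly like this (a temporary debugging edit inside modeling_llama.py, not the exact code from my run):

if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    shape_before = tuple(key_states.shape)
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    shape_after = tuple(key_states.shape)
    if shape_before != shape_after:
        # for the failing samples, this is where key_states jumps from 2048 to 4096
        print(f"layer {self.layer_idx}: key_states {shape_before} -> {shape_after}")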

As a temporary workaround, I commented out the section involving past_key_value.update, and fine-tuning then proceeded without errors. Do you have any insight into this bug?
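One other direction I am considering, instead of editing modeling_llama.py, is to skip the KV-cache path during training entirely. This is only a guess that the cache interacts badly with the gradient-checkpointing recompute visible in the traceback, not something I have verified; a minimal sketch (`model` and `batch` are placeholders for the objects built in llama_finetuning.py):

# Unverified workaround sketch: the KV cache is not needed for training,
# so disabling it sidesteps past_key_value.update altogether.
model.config.use_cache = False
model.gradient_checkpointing_enable()  # checkpointing as already used in llama-recipes

# `batch` stands for one dict produced by the existing dataloader (assumption)
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
    use_cache=False,  # pass per call as well, in case the config flag gets overridden
)
loss = outputs.loss
loss.backward()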