MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome
Apache License 2.0

RuntimeError: The expanded size of the tensor (673) must match the existing size (512) at non-singleton dimension 1. Target sizes: [8, 673]. Tensor sizes: [1, 512] #42

Open TheRainInSpain opened 1 year ago

TheRainInSpain commented 1 year ago

I tried to run fine-tuning. I put my sequences and labels into CSV files, following the format of the sample data, but training failed with the error below and I do not know why. Could anyone help? Thank you.

`trainer.train()
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/trainer.py", line 1553, in train
    return inner_training_loop(
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/trainer.py", line 1835, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/trainer.py", line 2704, in compute_loss
    outputs = model(**inputs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1562, in forward
    outputs = self.bert(
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/u7343217/.conda/envs/canberra_city/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 988, in forward
    buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
RuntimeError: The expanded size of the tensor (673) must match the existing size (512) at non-singleton dimension 1. Target sizes: [8, 673]. Tensor sizes: [1, 512]`

It seems that the tensor dimensions do not match, but I did not change any code. The command I ran was:

python train.py --model_name_or_path zhihan1996/DNABERT-2-117M --data_path splited/csv_files --kmer -1 --run_name DNABERT2_test1 --model_max_length 700 --per_device_train_batch_size 8 --per_device_eval_batch_size 16 --gradient_accumulation_steps 1 --learning_rate 3e-5 --num_train_epochs 3 --fp16 --save_steps 200 --output_dir output/dnabert2 --evaluation_strategy steps --eval_steps 200 --warmup_steps 50 --logging_steps 100000 --overwrite_output_dir True --log_level info --find_unused_parameters False

I set the max length hyperparameter (--model_max_length) to 700 because my sequences are long.
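
The last frame of the traceback calls `buffered_token_type_ids.expand(batch_size, seq_length)` on a buffer of shape [1, 512] with a target of [8, 673]. As a rough illustration of why that fails, the sketch below models the shape rule in plain Python. `expand_shape` is a hypothetical helper written for this example, not a PyTorch function, and it is simplified to shapes of equal rank. The 512 presumably comes from the position limit of the stock BERT embeddings, and 673 from the longest tokenized sequence in the batch; both readings are assumptions drawn from the error message.

```python
def expand_shape(existing, target):
    """Model the shape check behind Tensor.expand for same-rank shapes.

    A dimension of size 1 may be stretched to any size; any other
    dimension must already equal the requested size.
    """
    if len(existing) != len(target):
        raise ValueError("this simplified model only handles equal ranks")
    for dim, (have, want) in enumerate(zip(existing, target)):
        if have != 1 and have != want:
            raise ValueError(
                f"expanded size ({want}) must match existing size ({have}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(target)


# Buffer shape taken from the error message.
buffer_shape = (1, 512)

# Dimension 0 is a singleton, so stretching 1 -> 8 is allowed.
ok_shape = expand_shape(buffer_shape, (8, 512))

# Dimension 1 is 512, not 1, so it cannot be stretched to 673.
try:
    expand_shape(buffer_shape, (8, 673))
    failure = None
except ValueError as exc:
    failure = str(exc)

print(ok_shape)
print(failure)
```

Under this reading, a batch padded to at most 512 tokens would pass the check, while any batch longer than 512 tokens would hit the same error for as long as the stock BERT code path is in use.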

Zhihan1996 commented 1 year ago

Hey,

It looks like you are using the original BERT implementation from transformers instead of the DNABERT-2 implementation, which is probably where the error comes from. Are you running the provided code directly, or did you write your own? Could you first pin the transformers version by running pip install transformers==4.29?
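
Before retrying, it may help to confirm which transformers version the environment actually has. The following is a minimal sketch using only the standard library's `importlib.metadata`. `installed_version` and `matches_pin` are hypothetical helpers written for this example, and treating any 4.29.x release as matching the suggested pin is an assumption, not something stated in this thread.

```python
from importlib import metadata


def matches_pin(installed, pinned):
    """Return True if `installed` equals `pinned` or is a patch release of it."""
    installed_parts = installed.split(".")
    pinned_parts = pinned.split(".")
    # Compare only as many components as the pin specifies.
    return installed_parts[: len(pinned_parts)] == pinned_parts


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("transformers")
if version is None:
    print("transformers is not installed in this environment")
elif matches_pin(version, "4.29"):
    print(f"transformers {version} matches the suggested 4.29 pin")
else:
    print(f"transformers {version} differs from the suggested 4.29 pin")
```

Running this inside the same conda environment that appears in the traceback shows whether the pin took effect.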