texttron / tevatron

Tevatron - A flexible toolkit for neural retrieval research and development.
http://tevatron.ai
Apache License 2.0

DDP training seems to pass the same data to each GPU #89

Closed · aken12 closed this 1 month ago

aken12 commented 9 months ago

Hi :)

I'm trying to train a model with distributed data parallel (DDP) across multiple GPUs, but it appears that each GPU is receiving the same data (I observed tensors with identical values on cuda:0 through cuda:3). I'm using the following command, based on the example_dpr.md guide:

python -m torch.distributed.launch --nproc_per_node=4 -m tevatron.driver.train \
  --output_dir model_nq \
  --model_name_or_path bert-base-uncased \
  --save_steps 20000 \
  --dataset_name Tevatron/wikipedia-nq \
  --fp16 \
  --per_device_train_batch_size 32 \
  --positive_passage_no_shuffle \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --logging_steps 500 \
  --negatives_x_device \
  --overwrite_output_dir
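
The check I used to see that the tensors match across ranks was roughly along the lines of the sketch below (the helper name is made up, and it assumes torch.distributed has already been initialized by the launcher and that the tensor has the same shape on every rank):

import torch
import torch.distributed as dist

def same_batch_on_all_ranks(tensor: torch.Tensor) -> bool:
    # Gather this rank's copy of the tensor (e.g. the query input_ids of one
    # training step) from every process.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    # If every gathered copy matches rank 0's copy, all GPUs saw the same data.
    return all(torch.equal(gathered[0], t) for t in gathered[1:])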

To address this, I customized the get_train_dataloader function in TevatronTrainer to use torch.utils.data.distributed.DistributedSampler so that each GPU receives different data. However, others seem to be able to train without such a customization, so I'm wondering whether I'm misunderstanding something or missing a step. I look forward to your response.
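
Roughly, my customization looks like the sketch below (the class name ShardedTevatronTrainer is just for illustration, the exact Tevatron import path may differ by version, and it assumes the launcher has already initialized torch.distributed):

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from tevatron.trainer import TevatronTrainer  # import path may differ by version


class ShardedTevatronTrainer(TevatronTrainer):
    def get_train_dataloader(self) -> DataLoader:
        # Give each rank a distinct, non-overlapping shard of the training data.
        sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
            shuffle=True,
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=True,
            num_workers=self.args.dataloader_num_workers,
        )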

MXueguang commented 8 months ago

Hi @aken12, sorry for the late reply. Does the issue still exist? What was the transformers version you used?

aken12 commented 8 months ago

Thank you for getting back to me.

I am using transformers version 4.31.0. In the documentation, it says:

Note: The current code base has been tested with torch==1.10.1, faiss-cpu==1.7.2, transformers==4.15.0, datasets==1.17.0

Should I downgrade my transformers version to 4.15.0 to ensure compatibility?

MXueguang commented 8 months ago

I tried 4.34.0 recently and it seems to have some issues with data loading.

Downgrading to transformers==4.20.0 should work.

I'll take a look at what the difference is on the transformers side.
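
As a quick sanity check after downgrading (just a trivial illustrative snippet), you can confirm the interpreter is actually picking up the pinned version:

import transformers

# Fail loudly if the environment still resolves to a newer transformers build.
assert transformers.__version__.startswith("4.20"), transformers.__version__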

aken12 commented 8 months ago

I followed your advice and downgraded to transformers==4.20.0, and everything is working smoothly now. I appreciate your attentive support!