texttron / tevatron

Tevatron - A flexible toolkit for neural retrieval research and development.
http://tevatron.ai
Apache License 2.0

unable to train #134

Open riyajatar37003 opened 3 months ago

riyajatar37003 commented 3 months ago

I installed Tevatron using "pip install tevatron" and am trying to train RepLLaMA from the examples folder, but I get the following import error:

Traceback (most recent call last):
  File "un_3/tevatron/examples/repllama/train.py", line 12, in <module>
    from tevatron.arguments import ModelArguments, DataArguments, \
ImportError: cannot import name 'TevatronTrainingArguments' from 'tevatron.arguments' (/tmp/.local/lib/python3.10/site-packages/tevatron/arguments.py)

Could you give simple steps to train RepLLaMA?
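The traceback shows the script importing tevatron.arguments from the pip-installed copy under site-packages. A minimal Python sketch for checking which copy of a module gets picked up and whether it defines a given name (inspect_module is a hypothetical helper, not part of Tevatron):

```python
import importlib
import importlib.util


def inspect_module(module_name, attribute):
    """Return (origin, has_attribute) for module_name.

    origin is the file the module would be loaded from, or None if the module
    (or one of its parent packages) cannot be found.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # For dotted names, find_spec imports the parent package first;
        # a missing parent package raises here.
        return None, False
    if spec is None:
        return None, False
    module = importlib.import_module(module_name)
    return spec.origin, hasattr(module, attribute)


if __name__ == "__main__":
    origin, found = inspect_module("tevatron.arguments", "TevatronTrainingArguments")
    print("tevatron.arguments loaded from:", origin)
    print("defines TevatronTrainingArguments:", found)
```

After an editable install ("pip install -e ."), the reported origin should point into the cloned checkout rather than site-packages.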

MXueguang commented 3 months ago

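Hi, please clone the repo and install it with "pip install -e ." instead. The example scripts in the repo follow the current source tree, so the PyPI release you installed may not define names they import, such as TevatronTrainingArguments. After that, the command from the main page should reproduce RepLLaMA/RepMistral: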

deepspeed --include localhost:0,1,2,3 --master_port 60000 --module tevatron.retriever.driver.train \
  --deepspeed deepspeed/ds_zero3_config.json \
  --output_dir retriever-mistral \
  --model_name_or_path mistralai/Mistral-7B-v0.1 \
  --lora \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --save_steps 50 \
  --dataset_name Tevatron/msmarco-passage-aug \
  --query_prefix "Query: " \
  --passage_prefix "Passage: " \
  --bf16 \
  --pooling eos \
  --append_eos_token \
  --normalize \
  --temperature 0.01 \
  --per_device_train_batch_size 8 \
  --gradient_checkpointing \
  --train_group_size 16 \
  --learning_rate 1e-4 \
  --query_max_len 32 \
  --passage_max_len 156 \
  --num_train_epochs 1 \
  --logging_steps 10 \
  --overwrite_output_dir \
  --gradient_accumulation_steps 4
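When adapting this command to a different number of GPUs, it helps to know how the batch flags combine. A small sketch of the arithmetic, assuming plain data parallelism (each GPU processes its own micro-batch, as under DeepSpeed ZeRO) and assuming --train_group_size counts every passage attached to a query (positive plus negatives). The function batch_stats is hypothetical:

```python
def batch_stats(num_gpus, per_device_batch, grad_accum_steps, train_group_size):
    """Sizes implied by the data-parallel flags of a retriever training run.

    Assumes each GPU sees its own micro-batch and that train_group_size is the
    total number of passages per query.
    """
    # Queries contributing to one optimizer step across all GPUs.
    queries_per_step = num_gpus * per_device_batch * grad_accum_steps
    # Passages encoded per optimizer step: train_group_size per query.
    passages_per_step = queries_per_step * train_group_size
    # Passages one GPU encodes in a single forward pass (one micro-batch).
    passages_per_forward = per_device_batch * train_group_size
    return {
        "queries_per_step": queries_per_step,
        "passages_per_step": passages_per_step,
        "passages_per_forward": passages_per_forward,
    }


# Values from the command: 4 GPUs (localhost:0,1,2,3), batch 8, accumulation 4, group 16.
print(batch_stats(num_gpus=4, per_device_batch=8, grad_accum_steps=4, train_group_size=16))
# {'queries_per_step': 128, 'passages_per_step': 2048, 'passages_per_forward': 128}
```

With fewer GPUs, raising --gradient_accumulation_steps proportionally keeps queries_per_step unchanged, although with ordinary gradient accumulation the pool of in-batch negatives seen in each forward pass does not grow.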

riyajatar37003 commented 2 months ago

I am getting this error:

UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

I set use_reentrant to both True and False in training_args.gradient_checkpointing_kwargs, but I still get the same error.
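Under Python's default warning filters, the message quoted above is a UserWarning and does not by itself stop training; if the run actually aborts, the exception responsible should appear elsewhere in the log. If the warning is only noise (for example, because some checkpointing call never receives the configured kwargs, which is an assumption here), a minimal sketch using the standard warnings module hides just this message. The helper silence_reentrant_warning is hypothetical; calling it near the top of the training script means every launched process runs it:

```python
import warnings

# Start of the message quoted in this thread. filterwarnings treats `message`
# as a regular expression matched against the beginning of the warning text.
REENTRANT_PATTERN = r"torch\.utils\.checkpoint: please pass in use_reentrant"


def silence_reentrant_warning():
    """Ignore only the use_reentrant UserWarning; other warnings still show."""
    warnings.filterwarnings("ignore", message=REENTRANT_PATTERN, category=UserWarning)
```

This only hides the message; it does not change which checkpointing variant PyTorch uses.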

MXueguang commented 1 month ago

What is your PyTorch version?
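One way to answer this with the standard library is to collect the PyTorch version alongside the other packages this training setup likely involves. The list beyond torch and tevatron is an assumption, and package_versions is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Assumed package list; adjust to whatever the environment actually uses.
    for name, ver in package_versions(
        ["torch", "transformers", "deepspeed", "peft", "tevatron"]
    ).items():
        print(f"{name}: {ver or 'not installed'}")
```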