syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

gradient_checkpointing=True issue in TrainingArguments #28

Closed: lolshuo closed this issue 10 months ago

lolshuo commented 10 months ago

I'm using the RetNet base config with the following TrainingArguments:

args = TrainingArguments(
    output_dir="/content/retnet-xsum",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=370,
    logging_steps=370,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    learning_rate=6e-4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # torch_compile=True,
    # checkpointing:
    save_steps=370,
    optim="adafactor",
    # optim="adamw_torch",
    fp16=True,
    # push_to_hub=True,
)
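
For reference, a minimal sketch of how these arguments would typically be handed to the Hugging Face Trainer. The model, tokenizer, and dataset objects below are placeholders standing in for the RetNet model and XSum data implied by this report; they are assumptions for illustration, not code from this repo.

from transformers import Trainer

# Hypothetical wiring of the arguments above into a Trainer run.
# `model`, `tokenizer`, `train_ds`, and `eval_ds` are placeholders
# (assumed, not taken from the original report or the repo's API).
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
)
trainer.train()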

But I'm getting the following error when running trainer.train():

[screenshot of the error traceback, not reproduced here]

syncdoth commented 10 months ago

The recent commit should solve this issue.
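
For anyone still on an older checkout, a hedged workaround sketch: gradient checkpointing can also be enabled directly on the model through the standard PreTrainedModel API, with caching turned off since the two are incompatible. Whether this addresses the exact traceback above cannot be confirmed here (the screenshot is not reproduced), and the class name and import paths below are assumptions, not verified against this repo.

# Hedged workaround sketch; class name and import paths are assumptions.
from retnet.configuration_retnet import RetNetConfig   # assumed import path
from retnet.modeling_retnet import RetNetForCausalLM    # assumed class name/path

config = RetNetConfig()                  # base config, as in the report
model = RetNetForCausalLM(config)

# Standard Hugging Face PreTrainedModel calls (assuming the model subclasses it):
model.gradient_checkpointing_enable()    # same effect as gradient_checkpointing=True in TrainingArguments
model.config.use_cache = False           # caching during training conflicts with gradient checkpointing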