Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License
gradient_checkpointing=True issue in TrainingArguments #28
I'm using the RetNet base config with the following TrainingArguments:
```python
args = TrainingArguments(
    output_dir="/content/retnet-xsum",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=370,
    logging_steps=370,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    learning_rate=6e-4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # torch_compile=True,
    # checkpointing:
    save_steps=370,
    optim="adafactor",
    # optim="adamw_torch",
    fp16=True,
    # push_to_hub=True,
)
```
But I'm getting this error when running trainer.train():
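Since the traceback is not included above, the exact cause can't be confirmed, but one common source of `gradient_checkpointing=True` failures with Hugging Face models is the key/value cache: cached activations conflict with checkpoint recomputation, and many models require `use_cache=False` when checkpointing is enabled. A hedged sketch of that workaround (an assumption, not a confirmed fix for this issue):

```python
# Hypothetical workaround, assuming the error stems from the common
# use_cache / gradient-checkpointing conflict in Hugging Face models.
# `model` stands in for the RetNet model instance passed to the Trainer.

model.config.use_cache = False        # KV caching is incompatible with recomputation
model.gradient_checkpointing_enable() # equivalent to gradient_checkpointing=True
                                      # in TrainingArguments
```

If the actual traceback points elsewhere (e.g. an unsupported module in the checkpointed forward), this change alone may not be sufficient.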