Amelie-Schreiber / esm2_loras

Trying to train LoRAs for ESM-2
https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites

ValueError: EsmForTokenClassification does not support gradient checkpointing #2

Open suresh-pokharel opened 7 months ago

suresh-pokharel commented 7 months ago

Hi Amelie, I am looking to fine-tune an ESM model on a small downstream task. I am following your article: https://huggingface.co/blog/AmelieSchreiber/esm2-ptm

I am getting the following error when calling gradient_checkpointing_enable(). You may already have a solution or suggestions for it.

    model.gradient_checkpointing_enable() # SP commented
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sureshp/anaconda3/envs/qlora/lib/python3.11/site-packages/transformers/modeling_utils.py", line 1631, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
ValueError: EsmForTokenClassification does not support gradient checkpointing.

Thanks.
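
A possible workaround, offered as a sketch rather than a confirmed fix: the check in `modeling_utils.py` that raises this error reads the model's `supports_gradient_checkpointing` attribute, so the call can be guarded on that attribute and training can proceed without checkpointing (at a higher memory cost) when the model class does not declare support. The helper name `enable_checkpointing_if_supported` is hypothetical and not part of transformers.

```python
def enable_checkpointing_if_supported(model) -> bool:
    """Enable gradient checkpointing only when the model class declares support.

    Returns True if checkpointing was enabled, False if it was skipped.
    """
    # Assumption: `model` is a transformers PreTrainedModel, which exposes the
    # class attribute `supports_gradient_checkpointing`.
    if getattr(model, "supports_gradient_checkpointing", False):
        model.gradient_checkpointing_enable()
        return True
    # Fall back to ordinary training instead of raising ValueError.
    print(f"{type(model).__name__} does not support gradient checkpointing; skipping.")
    return False
```

In the training script, `enable_checkpointing_if_supported(model)` would take the place of the direct `model.gradient_checkpointing_enable()` call.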