lucidrains / enformer-pytorch

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch
MIT License

Fine-tuning without freezing transformer parameters leads to poor performance #16

Closed: Zehui127 closed this issue 1 year ago

Zehui127 commented 1 year ago

Dear team,

As far as I understand, the current fine-tuning feature freezes the transformer and fine-tunes only the linear heads. Has anyone tried fine-tuning the whole model without freezing the transformer parameters? I tried, but the performance was poor: for example, when I simply fine-tuned the model on the original Basenji training set, the correlation score dropped from 0.6 to 0.1. Has anyone encountered a similar issue, and what might be the cause?
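
To make the distinction concrete, here is a minimal PyTorch sketch of head-only versus full fine-tuning. `ToyTrunk`, `configure`, and `count_trainable` are hypothetical stand-ins (a small conv stem plus a transformer), not enformer-pytorch's classes; the only point is that freezing means turning off `requires_grad` on the trunk, while full fine-tuning hands every trunk parameter to the optimizer.

```python
import torch
from torch import nn

# Hypothetical stand-in for a pretrained trunk (NOT the real Enformer class):
# a conv stem with BatchNorm followed by a small transformer encoder.
class ToyTrunk(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv1d(4, dim, kernel_size=5, padding=2),
                                  nn.BatchNorm1d(dim), nn.GELU())
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)

    def forward(self, x):                  # x: (batch, 4, length), one-hot DNA
        h = self.stem(x).transpose(1, 2)   # -> (batch, length, dim)
        return self.transformer(h)

def configure(trunk: nn.Module, head: nn.Module, finetune_trunk: bool):
    """Head-only fine-tuning if finetune_trunk is False, full fine-tuning otherwise."""
    trunk.requires_grad_(finetune_trunk)
    head.requires_grad_(True)
    params = [p for m in (trunk, head) for p in m.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=1e-4)

def count_trainable(*modules):
    return sum(p.numel() for m in modules for p in m.parameters() if p.requires_grad)
```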

lucidrains commented 1 year ago

@Zehui127 it is difficult because of the batchnorms

however, researchers have recently been reporting good results with LoRA (low-rank adaptation)
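
For reference, LoRA keeps a pretrained weight frozen and learns a low-rank additive update on top of it. The sketch below is a generic version written against a plain `nn.Linear`; `LoRALinear`, the rank `r`, and the scale `alpha` are illustrative choices, and which Enformer projections one would wrap is an assumption left open here.

```python
import torch
from torch import nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper (sketch): y = W x + b + (alpha / r) * B(A(x)),
    with the pretrained W and b frozen and only the low-rank A, B trained."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)          # freeze pretrained weights
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_b.weight)       # the update starts at exactly zero
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
```

Because `lora_b` starts at zero, the wrapped layer initially reproduces the pretrained output, and only `r * (in_features + out_features)` parameters per wrapped layer are trained.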

i could add this as an option, on the condition that you let me know your results once you try it?

Zehui127 commented 1 year ago

@lucidrains Thanks for responding to this. It would be much appreciated if there were an option for fine-tuning without freezing the transformer parameters. I will definitely let you know the results if you make the change.

lucidrains commented 1 year ago

@Zehui127 oh, it has been a while and i have forgotten what i had built

have you tried this option yet?

i'm not really caught up on the fine tuning literature, but my impression is that you don't really need to unfreeze the entire network

lucidrains commented 1 year ago

@Zehui127 LoRA should be quite good, i see people doing dreambooth fine tuning in stable diffusion using that technique

Zehui127 commented 1 year ago

@lucidrains I have actually looked at the code for this option; it basically freezes the transformer and convolutional layers and only adds a linear layer that is multiplied with the linear heads. We don't want that, because we are trying to train Enformer and the downstream model end to end. I suppose the original TensorFlow Enformer implementation unfroze the entire network. Is it possible to avoid the issue with the batchnorms? I can also try to implement it if you give me some hints.

lucidrains commented 1 year ago

that option unfreezes the layernorms of the transformer, a popular technique. there's also another option to unfreeze the penultimate layers of the transformer
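
Unfreezing only the layer norms can be sketched as below, assuming a generic PyTorch module; `unfreeze_layernorms_only` is a hypothetical helper, not the repo's code, and in practice the new head would be trained as well.

```python
import torch
from torch import nn

def unfreeze_layernorms_only(model: nn.Module) -> list:
    """Freeze every parameter, then re-enable gradients for LayerNorm affine params."""
    model.requires_grad_(False)
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.requires_grad_(True)        # weight (gain) and bias of this norm
    return [p for p in model.parameters() if p.requires_grad]
```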

lucidrains commented 1 year ago

oh actually, never mind, I never integrated that last-N-layers fine-tuning feature

what is your batch size when fine-tuning?

lucidrains commented 1 year ago

fine-tuning is still more art than science. stuff like catastrophic forgetting happens often when you unfreeze the entire network

Zehui127 commented 1 year ago

So far, I have only fine-tuned on a single V100 (30 GB) with a batch size of 2; larger batch sizes trigger a CUDA out-of-memory error.

lucidrains commented 1 year ago

yeah, batch norm won't work unless your batch size is at least 32 🤷‍♂️
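
One common workaround when the batch is small, sketched here under the assumption of standard `torch.nn` BatchNorm layers, is to freeze them: in eval mode they normalize with their stored running statistics instead of per-batch statistics. `freeze_batchnorm` is a hypothetical helper; since `model.train()` flips every submodule back to training mode, it has to be re-applied after each such call.

```python
import torch
from torch import nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def freeze_batchnorm(model: nn.Module) -> nn.Module:
    """Put every BatchNorm layer in eval mode (use stored running statistics,
    stop updating them) and stop training its affine parameters.
    model.train() undoes the eval() part, so call this again afterwards."""
    for module in model.modules():
        if isinstance(module, _BN_TYPES):
            module.eval()
            module.requires_grad_(False)
    return model
```

With frozen statistics, each sample's output no longer depends on which other samples share its batch.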

batchnorm, why do you have to work so well shakes fist

lucidrains commented 1 year ago

yeah, I think unfreezing the transformer may be your best bet, since it is free of batch normalization (while keeping conv stem frozen)
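
Keeping the conv stem frozen while training the transformer might look like the following. `ToyModel` and its `stem` / `transformer` attribute names are hypothetical stand-ins for the real Enformer layout, so check `named_children()` on the actual model before adapting this.

```python
import torch
from torch import nn

# Hypothetical stand-in model: a conv stem with BatchNorm, then a transformer.
# The attribute names "stem" and "transformer" are assumptions, not Enformer's.
class ToyModel(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv1d(4, dim, 3, padding=1),
                                  nn.BatchNorm1d(dim), nn.GELU())
        layer = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, x):                  # x: (batch, 4, length)
        return self.transformer(self.stem(x).transpose(1, 2))

def freeze_stem_train_transformer(model: nn.Module) -> list:
    """Freeze the conv stem (weights and BatchNorm statistics), train the transformer.
    model.train() switches the stem back to train mode, so call this after it."""
    model.requires_grad_(False)
    model.transformer.requires_grad_(True)
    model.stem.eval()                      # BatchNorm uses its stored running stats
    return [p for p in model.parameters() if p.requires_grad]
```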

lucidrains commented 1 year ago

I'll take a look tomorrow morning and see if I can easily squeeze in the unfreezing of last N layers of transformer setting, so you can play around with that

lucidrains commented 1 year ago

may I ask whose lab you are working for?

Zehui127 commented 1 year ago

So does that mean the following will work?

  1. If I train the model with a batch size of more than 32, I can unfreeze the whole network (conv + transformer).
  2. If I train the model with a batch size of less than 32, I should freeze the conv layers and unfreeze the transformer.

Zehui127 commented 1 year ago

I'm a PhD student at Imperial College London.

lucidrains commented 1 year ago

well, keep in mind that catastrophic forgetting can still happen, especially if you unfreeze the whole network at once. i recommend searching the literature on that to see what the latest bag of tricks is

higher batch sizes will help with the batch norm yeah

Zehui127 commented 1 year ago

Thanks for your kind help!

lucidrains commented 1 year ago

ahh ok cool, then by your tomorrow afternoon I can probably have the option to unfreeze the entire transformer or just its last N layers finished

lucidrains commented 1 year ago

no problem, I want this repository to be useful 🌝

lucidrains commented 1 year ago

@Zehui127 Hi Zehui

i reviewed the code this morning and noticed that i already freeze the batchnorms, so if you are using any of the fine-tuning adapters, batchnorm should not be the issue (as long as you are accumulating gradients properly and not training on a batch size of 2). if you did everything correctly, the likely issue is catastrophic forgetting

i've finished integrating the feature for freezing the entire network except the last N layers of the transformer. you can use it by setting this keyword argument on the adapter's forward to an integer
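
The keyword argument itself is not named in this thread, so the sketch below does not use the repo's API. It shows the general idea with a hypothetical helper, `unfreeze_last_n_layers`, applied to a stock `nn.TransformerEncoder`: freeze the whole trunk, then re-enable gradients for the last `n` blocks and for the head.

```python
import torch
from torch import nn

def unfreeze_last_n_layers(trunk: nn.Module, layers: nn.ModuleList, n: int,
                           head: nn.Module) -> list:
    """Freeze the trunk, then unfreeze its last n blocks plus a separate head.
    `head` must not be a submodule of `trunk`, or its params would be listed twice."""
    trunk.requires_grad_(False)
    if n > 0:                              # guard: layers[-0:] would select every block
        for block in layers[-n:]:
            block.requires_grad_(True)
    head.requires_grad_(True)
    return [p for p in list(trunk.parameters()) + list(head.parameters())
            if p.requires_grad]
```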

lucidrains commented 1 year ago

@Zehui127 which adapter are you using for the finetuning? the head adapter?

Zehui127 commented 1 year ago

@lucidrains Thanks for helping with this. I was using the HeadAdapterWrapper before. I will try last-N-layers fine-tuning and let you know the results.

lucidrains commented 1 year ago

@Zehui127 were you accumulating gradients when training with a batch size of 2? i'm wondering if i should offer a training wrapper that takes care of this too (i noticed i actually left myself a todo for researchers like you, but never completed it)

lucidrains commented 1 year ago

@Zehui127 also, have you tried the discrete key / value bottleneck feature? https://arxiv.org/abs/2207.11240 i quite liked the paper, but have never tried it out on a task
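
For readers unfamiliar with the paper, here is a heavily simplified sketch of a discrete key-value bottleneck. `DiscreteKVBottleneck` is hypothetical, not the repo's feature: the encoder embedding is split into heads, each slice is snapped to its nearest key in a frozen per-head codebook, and only the value vectors attached to the selected keys receive gradients. Simplification: the keys here are random, whereas the paper, as we understand it, first fits the keys to encoder outputs and then freezes them; consult the paper for the actual procedure.

```python
import torch
from torch import nn

class DiscreteKVBottleneck(nn.Module):
    """Simplified sketch after arXiv:2207.11240: frozen keys, learnable values."""
    def __init__(self, dim, heads=4, num_keys=64, value_dim=16):
        super().__init__()
        assert dim % heads == 0
        self.heads = heads
        # Simplification: random frozen keys (a buffer, so never trained).
        self.register_buffer("keys", torch.randn(heads, num_keys, dim // heads))
        self.values = nn.Parameter(torch.randn(heads, num_keys, value_dim) * 0.02)

    def forward(self, emb):                            # emb: (batch, dim)
        b = emb.shape[0]
        chunks = emb.reshape(b, self.heads, -1)        # (batch, heads, dim / heads)
        # Euclidean distance from each head's slice to that head's keys
        dists = torch.cdist(chunks.transpose(0, 1), self.keys)  # (heads, batch, num_keys)
        idx = dists.argmin(dim=-1)                     # (heads, batch); not differentiable
        vals = torch.stack([self.values[h, idx[h]] for h in range(self.heads)], dim=1)
        return vals.reshape(b, -1)                     # (batch, heads * value_dim)
```

Because the key lookup is an argmin, gradients reach only the selected values, which is what limits how much of the model a downstream task can overwrite.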

Zehui127 commented 1 year ago

@lucidrains, it looks like a very interesting paper; I will try it and let you know the results. As for gradient accumulation, I did use it in some earlier runs, but it didn't improve things much. I'm currently refactoring the code to allow a larger effective batch size. Hopefully, with all these changes, the performance will improve. Thanks again.

lucidrains commented 1 year ago

@Zehui127 ok, keep me in the know!

lucidrains commented 1 year ago

@Zehui127 you'll need gradient accumulation for large effective batch size, i'll build it into some training wrapper at some point. it is needlessly confusing for the uninitiated
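
A bare-bones gradient accumulation loop, assuming a standard PyTorch model, optimizer, and a loss with mean reduction, looks like this; `train_with_accumulation` is a hypothetical helper, not the training wrapper mentioned above.

```python
import torch
from torch import nn

def train_with_accumulation(model, optimizer, loss_fn, micro_batches, accum_steps):
    """Sum gradients over `accum_steps` micro-batches before each optimizer step,
    giving an effective batch of accum_steps * micro_batch_size.
    Leftover micro-batches at the end (fewer than accum_steps) are not stepped."""
    model.train()
    optimizer.zero_grad()
    steps = 0
    for i, (x, y) in enumerate(micro_batches, start=1):
        loss = loss_fn(model(x), y) / accum_steps   # scale so the sum is an average
        loss.backward()                             # gradients accumulate in .grad
        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            steps += 1
    return steps
```

Dividing each loss by `accum_steps` makes the accumulated gradient equal to that of one batch made of all the micro-batches (for mean-reduced losses and equal micro-batch sizes). BatchNorm layers left in training mode still compute statistics over each micro-batch, though, so accumulation alone does not fix the small-batch batchnorm problem discussed earlier in the thread.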