bilal2vec / lm-training-research-project

MIT License

Run Adafactor with Square root Scheduler #1

Open JessicaLopezEspejel opened 3 years ago

JessicaLopezEspejel commented 3 years ago

Hello, thank you for your work. I am interested in your Adafactor implementation. I want to use the same training hyperparameters as PEGASUS (https://arxiv.org/pdf/1912.08777.pdf) to train my model on the PubMed dataset. In the article, the authors say that they use square root learning rate decay. Could you give me an example of how I can use your code with this scheduler, please?
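The PEGASUS paper only names "square root learning rate decay". A common reading of that phrase is a linear warmup followed by a rate proportional to 1/sqrt(step). This is a minimal sketch under that assumption. The function name `rsqrt_schedule` is hypothetical. The defaults are also illustrative: `peak_lr=5e-4` comes from the call in this issue, and `warmup_steps=10_000` from the warmup length mentioned in the reply. Neither default is a confirmed PEGASUS setting.

```python
import math


def rsqrt_schedule(step: int, peak_lr: float = 5e-4, warmup_steps: int = 10_000) -> float:
    """Linear warmup followed by inverse square root decay (assumed form).

    peak_lr and warmup_steps are illustrative defaults, not PEGASUS settings.
    """
    step = max(step, 1)                     # avoid division by zero at step 0
    warmup = step / warmup_steps            # ramps linearly from ~0 up to 1
    decay = math.sqrt(warmup_steps / step)  # equals 1 at warmup_steps, then falls like 1/sqrt(step)
    return peak_lr * min(warmup, decay)


if __name__ == "__main__":
    for s in (1_000, 10_000, 40_000, 160_000):
        print(s, rsqrt_schedule(s))
```

With these defaults, the rate peaks at 5e-4 at step 10,000 and halves each time the step count quadruples. In a training loop, you would compute this value once per step before calling `apply_gradients`. Whether the repo's `AdafactorOptimizer` can take a per-step learning rate, and through which argument, is not confirmed here.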

I called your class like this: `optimizer = AdafactorOptimizer(learning_rate=5e-4)`, and at each training step I use `optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))`. Did I use your class correctly? And is this call enough to replicate the PEGASUS parameters?

Thanks a lot

bilal2vec commented 3 years ago

Hi, sorry for taking so long to respond.

The code I use with Adafactor here (https://github.com/bkkaggle/lm-training-research-project/blob/master/train_tfrecords.py#L228) uses a linear warmup for 10k steps and then decays the learning rate over the rest of the steps. In practice, I don't think there was much of a difference in performance between linear and inverse square root (inv sqrt) decay when I ran my experiments. It shouldn't be too difficult to modify the code to work with that scheduler, though.
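This is a minimal sketch of the schedule shape described above. It assumes the post-warmup decay is linear and reaches zero at a hypothetical `total_steps`. The actual logic in `train_tfrecords.py` may differ. The function name and all default values are illustrative.

```python
def linear_warmup_linear_decay(
    step: int,
    peak_lr: float = 5e-4,
    warmup_steps: int = 10_000,
    total_steps: int = 100_000,  # hypothetical training length
) -> float:
    """Linear warmup to peak_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        # Warmup phase: rate grows linearly from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    # Decay phase: rate shrinks linearly, clamped at 0 after total_steps.
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)
```

Switching to inverse square root decay only changes the post-warmup branch, which would return `peak_lr * math.sqrt(warmup_steps / step)` instead. That expression still equals `peak_lr` at the end of warmup, so the two phases join continuously.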

You might have better luck trying to use huggingface (https://github.com/huggingface/transformers) since I kinda threw this all together pretty messily.

As far as I know (at least when I did my work, around May), there isn't an official Adafactor implementation that's easy to use yet (https://github.com/tensorflow/addons/issues/522) (https://github.com/tensorflow/tensor2tensor/issues/1482). You can take a look at mine (https://github.com/bkkaggle/lm-training-research-project/blob/master/optimizers_tf.py) or at the implementation by the person whose code I used (https://github.com/bkkaggle/lm-training-research-project/blob/master/optimizers_tf.py).
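For readers comparing implementations, this is a heavily simplified pure-Python sketch of Adafactor's factored second-moment update for a single weight matrix. It follows the Adafactor paper (Shazeer & Stern, 2018) in three ways:

- It keeps only per-row and per-column sums of squared gradients.
- It uses the paper's decay schedule β2_t = 1 - t^(-0.8).
- It clips updates with threshold d = 1.

It deliberately omits relative step sizes, parameter-scale updates and momentum. It is not the code in `optimizers_tf.py`, and the name `adafactor_step` is hypothetical.

```python
import math


def adafactor_step(w, g, row_acc, col_acc, t, lr=5e-4, eps=1e-30, clip=1.0):
    """One simplified Adafactor update for a matrix parameter.

    w, g: n x m lists of floats (weights and gradient).
    row_acc (length n) and col_acc (length m) are updated in place.
    t: 1-based step count. Returns the new weights.
    """
    n, m = len(g), len(g[0])
    beta2 = 1.0 - t ** -0.8  # decay schedule from the paper; 0.0 at t=1
    sq = [[g[i][j] ** 2 + eps for j in range(m)] for i in range(n)]

    # Factored second moment: store n row sums and m column sums, not n*m values.
    for i in range(n):
        row_acc[i] = beta2 * row_acc[i] + (1 - beta2) * sum(sq[i])
    for j in range(m):
        col_acc[j] = beta2 * col_acc[j] + (1 - beta2) * sum(sq[i][j] for i in range(n))
    total = sum(row_acc)

    # Rank-1 estimate of the second moment: V[i][j] ~= row_acc[i] * col_acc[j] / sum(row_acc).
    u = [
        [g[i][j] / math.sqrt(row_acc[i] * col_acc[j] / total) for j in range(m)]
        for i in range(n)
    ]

    # Update clipping: shrink u when its RMS exceeds the threshold.
    rms = math.sqrt(sum(x * x for row in u for x in row) / (n * m))
    step = lr / max(1.0, rms / clip)
    return [[w[i][j] - step * u[i][j] for j in range(m)] for i in range(n)]
```

Memory is the point of the factoring. An n x m matrix needs only n + m accumulator values instead of n * m. When the squared gradient is rank-1, as with the gradient `[[1, 2], [3, 6]]`, the factored estimate is exact on the first step, so every entry of `u` is ±1.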

Bilal