tatp22 / linformer-pytorch

My take on a practical implementation of Linformer for Pytorch.
https://arxiv.org/pdf/2006.04768.pdf
MIT License

Any result on any benchmark? #13

Open twangnh opened 4 years ago

twangnh commented 4 years ago

Hi, thanks for sharing the implementation, could you pls share some reproduction results, possibly on some benchmarks?

tatp22 commented 4 years ago

Hi @twangnh,

I ran the test with pretrain_tutuorial_lm.py on this tutorial, and after 7 epochs I got a ppl of about 348. This is, of course, higher than the final ppl reported on that page (227) after the same number of epochs. However, that is because the model I ran is only the encoder with no masking, while the model in the tutorial is an encoder/decoder with masking. Since I only finished the masking recently, I will try another run soon with both an encoder and a decoder. That said, I don't know how the results will be skewed, because my masking method differs from what is proposed in the Attention Is All You Need paper; see #11 for more details on the masking.
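To make the setup concrete, here is a minimal sketch of the encoder-only configuration I mean. The argument names follow this repo's README from memory and may differ slightly between versions, and the hyperparameter values are purely illustrative, not the settings used for the run above:

```python
import torch
from linformer_pytorch import LinformerLM

# Rough sketch only: argument names follow this repo's README and the values
# are illustrative, not the benchmark settings.
model = LinformerLM(
    num_tokens=20000,  # vocabulary size
    input_size=512,    # maximum sequence length n
    channels=128,      # model dimension
    dim_k=64,          # Linformer projection dimension k
    dim_ff=128,        # feed-forward hidden dimension
    nhead=4,
    depth=2,
    dropout=0.1,
)

# No causal masking is applied here, so this is the encoder-only setting
# described above.
x = torch.randint(0, 20000, (1, 512))  # (batch_size, seq_len) of token ids
logits = model(x)                      # roughly (1, 512, num_tokens)
```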

That being said, if there is a benchmark that uses only the encoder, i.e. one that relies solely on the self-attention mechanism, it could be tested right now, and I could post the results of that.

phongnhhn92 commented 4 years ago

Hey @tatp22, I think you have done a tremendous job implementing this. I am looking forward to your updated results with the new masking, to allow a fair comparison with the Transformer architecture used in this tutorial. I think you could also test your implementation on the WikiText-103 dataset and compare it with the other architectures here.

tatp22 commented 3 years ago

An update: I am attempting to benchmark right now, but the tutorial here assumes a different data layout than my LinformerLM class expects. The data is passed in as (seq_len, batch_size) instead of (batch_size, seq_len), so the results were naturally worse than expected...

The model was working well on my non-LM tasks, so I was wondering why it was performing rather poorly here. I am looking at how to fix this so the results are consistent; look for an update on this soon.

phongnhhn92 commented 3 years ago

I would be surprised if the data format is the problem, because we can just use the permute() function to swap the axes.
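For example, a minimal sketch of that fix, assuming the tutorial's batches come in as (seq_len, batch_size) and LinformerLM expects (batch_size, seq_len). The model and shapes below are stand-ins, just to show where the transposes would go:

```python
import torch
import torch.nn as nn

# Purely illustrative: `model` is a stand-in for a LinformerLM instance, and
# the shapes mirror the tutorial's batching. All names here are hypothetical.
ntokens, seq_len, batch_size = 1000, 64, 8
model = nn.Sequential(nn.Embedding(ntokens, 32), nn.Linear(32, ntokens))
criterion = nn.CrossEntropyLoss()

data = torch.randint(0, ntokens, (seq_len, batch_size))       # tutorial layout
targets = torch.randint(0, ntokens, (seq_len * batch_size,))  # flattened targets

data = data.t().contiguous()     # -> (batch_size, seq_len), what LinformerLM expects
output = model(data)             # -> (batch_size, seq_len, ntokens)
output = output.transpose(0, 1)  # back to (seq_len, batch_size, ntokens)
loss = criterion(output.reshape(-1, ntokens), targets)
```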

tatp22 commented 3 years ago

So I ran the benchmark for wikitext-2 using this tutorial, and it actually seems to perform comparably to the standard transformer (at least, to the results they present on the website). However, it generalizes poorly to the test data.

I don't think this is an issue with the Linformer itself. Rather, I think it comes down to how the hyperparameters are chosen. I took my inspiration for this repo from the existing Sinkhorn Transformer, and I actually did a run with it, which performed similarly. I also tried lowering the LR, and the results were better. I'll leave my findings in the table below, and if anyone can come up with better results, please let me know. All runs are done for 9 epochs, and the hyperparameters closely follow what is in that tutorial.
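For anyone re-running this, the change I mean is just swapping a smaller learning rate into the tutorial's optimizer setup, roughly like the sketch below. If I recall correctly the tutorial uses SGD with lr = 5.0 and a StepLR schedule; the value here is only an example of "lower", not a tuned setting:

```python
import torch

# Illustrative only: `model` is a stand-in for the LinformerLM, and the
# learning rate below is an example, not a tuned value.
model = torch.nn.Linear(8, 8)  # stand-in for the LinformerLM
lr = 0.05
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
```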