mlpen / Nystromformer

Apache License 2.0

Retrieval accuracy different from official JAX/FLAX implementation #11

Open cwq159 opened 2 years ago

cwq159 commented 2 years ago

I wonder why the Retrieval accuracy is almost 20% higher than in the official JAX/FLAX implementation. As the paper says, "While we achieve consistent results reported in (Tay et al. 2020) for most tasks in our PyTorch reimplementation, the performance on Retrieval task is higher for all models following the hyperparameters in (Tay et al. 2020)." Is there any difference aside from the hyperparameters?

mlpen commented 2 years ago

Hi, sorry for the late response. We asked the LRA authors about this issue (https://github.com/google-research/long-range-arena/issues/18), but it has not been fully resolved. We suspect that a difference in hyperparameters is one of the reasons. However, when I checked the latest LRA repo a few minutes ago, it was still not clear to us exactly what the hyperparameters are or how the baselines were compared in the original paper. We used the data processing code from the LRA repo and only rewrote the model and training code. So the answer is still not clear.
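For anyone trying to narrow this down, one rough way to check the hyperparameter suspicion is to diff the two configurations key by key. The sketch below is only an illustration: `flatten`, `diff_configs`, and every key and value in the two dictionaries are hypothetical placeholders, not the actual LRA or Nystromformer settings.

```python
MISSING = "<missing>"  # marker for a key that exists on only one side


def flatten(config, prefix=""):
    """Flatten a nested dict into {"a.b": value} form."""
    flat = {}
    for key, value in config.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=name + "."))
        else:
            flat[name] = value
    return flat


def diff_configs(left, right):
    """Return {name: (left_value, right_value)} for every setting that differs."""
    left_flat, right_flat = flatten(left), flatten(right)
    diffs = {}
    for name in sorted(set(left_flat) | set(right_flat)):
        left_value = left_flat.get(name, MISSING)
        right_value = right_flat.get(name, MISSING)
        if left_value != right_value:
            diffs[name] = (left_value, right_value)
    return diffs


# Placeholder values for illustration only; not taken from either repo.
jax_config = {"learning_rate": 0.05, "batch_size": 32, "model": {"num_layers": 4}}
pytorch_config = {
    "learning_rate": 0.0001,
    "batch_size": 32,
    "model": {"num_layers": 2},
    "warmup": 800,
}

for name, (left_value, right_value) in diff_configs(jax_config, pytorch_config).items():
    print(f"{name}: JAX={left_value} PyTorch={right_value}")
```

Filling the two dictionaries with the real settings from each repo would list every hyperparameter that differs, including ones set on only one side.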