google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers
Apache License 2.0

Different hyper-parameters used for different models in image task. #34

Open mlpen opened 3 years ago

mlpen commented 3 years ago

Hi,

I found that different hyper-parameters (number of layers, dimension, etc.) are used for different models. Can you clarify how the baselines are compared?

For example, https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/longformer_base.py

config.model_type = "longformer"
config.model.num_layers = 4
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 4
config.model.classifier_pool = "MEAN"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/performer_base.py

config.model_type = "performer"
config.model.num_layers = 1
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 8
config.model.classifier_pool = "CLS"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/reformer_base.py

config.model_type = "reformer"
config.model.num_layers = 4
config.model.emb_dim = 64
config.model.qkv_dim = 32
config.model.mlp_dim = 64
config.model.num_heads = 8
config.model.classifier_pool = "CLS"
vanzytay commented 3 years ago

@MostafaDehghani for clarity on the image configs.

MostafaDehghani commented 3 years ago

@mlpen We ran an extensive hp search for every single model to make sure we got the best possible results from each. Especially for the CIFAR task, given that the results of different models are close, we wanted a rather large grid when searching the hp for each model separately, so you see different values for the number of layers, number of heads, etc. Basically, we prioritized getting the best possible result over keeping the number of trainable parameters similar across models. Hope this answers your question. Let us know if you have any issues reproducing the results, or if by any chance you end up with an hp configuration for any of these models that gives better results than what we reported in the paper.

alexmathfb commented 3 years ago

We ran an extensive hp search for every single model to make sure we got the best possible results from each.

Were the hyperparameters from the hp search used to produce Table 1, or was the hp search done after Table 1? I'm confused because the article says all Transformer models used the same fixed hyperparameters, but the hp search resulted in different hyperparameters:

" The large search space motivates us to follow a set of fixed hyperparameters (number of layers, heads, embedding dimensions, etc) for all models. "