google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers

Linear transformer performance #26

Closed · maximzubkov closed this issue 3 years ago

maximzubkov commented 3 years ago

Hello again!

Thank you for your work and for open-sourcing the codebase! I tried to run a ListOps experiment with the Linear Transformer and got test results that do not match the ones reported in the paper:

{"accuracy": 0.27000001072883606, "loss": 2.579371690750122, "perplexity": 13.188848495483398}

The model config was identical to the one used for the vanilla Transformer, and the only thing I changed in lra_benchmarks/listops/train.py was the following lines:

  if model_type == 'transformer':
    model = create_model(init_rng, transformer.TransformerEncoder, input_shape,
                         model_kwargs)
  elif model_type == 'linear_transformer':
    model = create_model(init_rng, linear_transformer.LinearTransformerEncoder,
                         input_shape, model_kwargs)
  else:
    raise ValueError('Model type not supported')
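
For reference, the config I point train.py at is simply the Transformer baseline config with the model type switched to linear_transformer. A minimal sketch is below; the field names and values are my reconstruction of the other LRA configs, so treat them as assumptions rather than the repo's actual file:

# Hypothetical lra_benchmarks/listops/configs/linear_transformer_base.py (sketch only).
# Hyperparameters mirror the Transformer baseline; the exact names/values are assumed.
import ml_collections


def get_config():
  """ListOps config that only swaps the model type to linear_transformer."""
  config = ml_collections.ConfigDict()
  config.model_type = 'linear_transformer'
  config.batch_size = 32
  config.learning_rate = 0.05
  config.num_train_steps = 5000
  # Model hyperparameters, kept identical to the Transformer run.
  config.emb_dim = 512
  config.num_heads = 8
  config.num_layers = 6
  config.qkv_dim = 512
  config.mlp_dim = 2048
  return config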

The experiments were run inside a Docker container (nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04) on 4 Tesla T4 GPUs with the following requirements:

jax>=0.2.4
flax>=0.2.2
ml-collections>=0.1.0
tensorboard>=2.3.0
tensorflow>=2.3.1
tensorflow-datasets>=4.0.1

and

pip install --upgrade jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I can share the full Dockerfile if needed.
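
A quick sanity check worth running inside the container (just a convenience snippet, not part of the repo) is to confirm that the CUDA build of jaxlib is actually picked up and that all four T4s are visible:

# Run inside the container before training.
import jax

print(jax.__version__)  # expect >= 0.2.4, per the requirements above
print(jax.devices())    # expect four GPU devices; CPU-only output means the CUDA wheel was not installed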

cifkao commented 3 years ago

I've also experienced this. I tried training on the ListOps task multiple times, and sometimes I get an accuracy around 16, sometimes around 27.

See also #6, #14, #20.

maximzubkov commented 3 years ago

@cifkao Thank you!

vanzytay commented 3 years ago

This has been fixed in a recent update. Please try again and let us know if you have any trouble. Sorry for the delay; these emails have been going to my spam folder.