google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers
Apache License 2.0
710 stars, 77 forks

Current code doesn't work with the latest Flax version and runs only on CPU #48

Open · ynahshan opened this issue 2 years ago

nurullahsevim commented 2 years ago

I have the same issue. Any update on this?

ynahshan commented 2 years ago

I hope the authors fix this issue; otherwise, this repo will be useless. For now, I work around it by using this PyTorch implementation of LRA: https://github.com/pkuzengqi/Skyformer

fecet commented 1 year ago

This repo uses the deprecated flax.optim API, so it cannot work with the latest Flax, but it does run on GPU with an older Flax (0.3.6 in my case).
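Since the scripts depend on the old flax.optim module, a fail-fast check at startup gives a readable message instead of a traceback deep inside training. A minimal stdlib-only sketch; module_available and require_module are hypothetical helper names, not part of this repo or of Flax.

```python
import importlib


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be imported."""
    try:
        importlib.import_module(name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False
    return True


def require_module(name: str, hint: str) -> None:
    """Raise a readable error early instead of failing deep inside training."""
    if not module_available(name):
        raise RuntimeError(f"{name} is not importable. {hint}")


# Hypothetical usage at the top of a training script (assumes it needs flax.optim):
# require_module("flax.optim", "Install an older Flax release, e.g. 0.3.6.")
```

Note that module_available really imports the module, so any import-time side effects (such as JAX initialization) happen during the check.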

kpe commented 1 year ago

@MostafaDehghani @ynahshan - any plans to update the repo to a newer jax/flax version?

fecet commented 1 year ago

I made some modifications to Linformer so that it works with the newest Flax; the remaining models can be ported in a similar way. You can pick it up if you are interested: https://github.com/fecet/long-range-arena

DaShenZi721 commented 1 year ago

@fecet Hello! Sorry to bother you. Have you ever encountered the following problem? I think it may be related to the version of flax.

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 28, in <module>
    from flax.deprecated import nn
ModuleNotFoundError: No module named 'flax.deprecated'

DaShenZi721 commented 1 year ago

@fecet This is my setting:

fecet commented 1 year ago

@DaShenZi721 flax.deprecated only exists in certain versions, as a temporary home for code that was about to be removed, so I guess your version is too old. You can check whether a deprecated module exists in the source code of your installed Flax.

DaShenZi721 commented 1 year ago

@fecet Thanks so much! I replaced "from flax.deprecated import nn" with "from flax import nn", and it works!
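Instead of editing the import by hand for each Flax version, a script can try both module paths in order. A minimal stdlib-only sketch; import_first is a hypothetical helper, and the usage line assumes nn is an importable subpackage under both flax.deprecated and flax, as the two import statements above suggest.

```python
import importlib
from types import ModuleType


def import_first(*names: str) -> ModuleType:
    """Import and return the first module in `names` that imports cleanly."""
    errors = []
    for name in names:
        try:
            return importlib.import_module(name)
        except ImportError as exc:
            errors.append(f"{name}: {exc}")
    raise ImportError("none of the candidates could be imported:\n" + "\n".join(errors))


# Hypothetical replacement for "from flax.deprecated import nn" in train.py:
# nn = import_first("flax.deprecated.nn", "flax.nn")
```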

AlexKay28 commented 1 year ago

@fecet Maybe it would be a good idea to pin strict versions of the Python packages in the requirements (and in other tools' configs)? Or even to create a Docker image specifically for this repo?

The JAX and Flax libraries change very fast, so the current scripts become outdated quickly.

Right now I can't even run the commands from the README =(

fecet commented 1 year ago

Apparently, Google has already abandoned this project, and there is nothing I (or any of us) can do about it :smile:

DaShenZi721 commented 1 year ago

Hello @AlexKay28! @fecet is right: Google's team has stopped maintaining this project, but I was able to run it successfully. I replaced "from flax.deprecated import nn" with "from flax import nn". My Python version is 3.8.16. Below is my conda environment:

absl-py==1.4.0
appdirs==1.4.4
astunparse==1.6.3
attrs==22.2.0
beautifulsoup4==4.12.2
blessed==1.20.0
cachetools==4.2.4
certifi @ file:///croot/certifi_1671487769961/work/certifi
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
conda-pack==0.6.0
contextlib2==21.6.0
cycler==0.11.0
decorator==5.1.1
dill==0.3.6
dm-tree==0.1.8
docker-pycreds==0.4.0
einops==0.6.1
filelock==3.11.0
fire==0.5.0
flatbuffers==2.0.7
flax==0.3.0
fonttools==4.39.0
future==0.18.3
gast==0.3.3
gdown==4.7.1
gin-config==0.5.0
gitdb==4.0.10
GitPython==3.1.31
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.58.0
gpustat==1.0.0
grpcio==1.51.3
h5py==2.10.0
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.12.0
jax==0.2.16
jax-smi==1.0.3
jaxlib==0.1.67+cuda111
keras==2.11.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==15.0.6.1
Markdown==3.4.1
MarkupSafe==2.1.2
matplotlib==3.5.3
ml-collections==0.1.1
msgpack==1.0.4
numpy==1.21.0
nvidia-ml-py==11.495.46
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.0
pathtools==0.1.2
Pillow==9.4.0
promise==2.3
protobuf==3.19.6
psutil==5.9.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
PyYAML==6.0
requests==2.28.2
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.9.3
sentry-sdk==1.16.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
soupsieve==2.4.1
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.0.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.19.0
tensorflow-text==2.11.0
termcolor==2.2.0
tqdm==4.65.0
typing_extensions==4.5.0
urllib3==1.26.14
wandb==0.13.11
wcwidth==0.2.6
Werkzeug==2.2.3
wrapt==1.15.0
zipp==3.15.0
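Following up on the suggestion to pin exact versions, an environment can be checked against a known-good list before training. A stdlib-only sketch using importlib.metadata (Python 3.8+); KNOWN_GOOD holds a subset of the pins listed above, and check_pins is a hypothetical helper, not part of the repo.

```python
from importlib import metadata
from typing import Callable, Dict, List

# Subset of the known-good environment listed above; extend as needed.
KNOWN_GOOD = {
    "flax": "0.3.0",
    "jax": "0.2.16",
    "jaxlib": "0.1.67+cuda111",
    "numpy": "1.21.0",
    "tensorflow": "2.11.0",
    "tensorflow-datasets": "4.0.0",
}


def check_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> List[str]:
    """Return one message per package whose installed version differs from its pin."""
    problems = []
    for package, wanted in sorted(pins.items()):
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package}: not installed (want {wanted})")
            continue
        if installed != wanted:
            problems.append(f"{package}: have {installed}, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_pins(KNOWN_GOOD):
        print(line)
```

The get_version parameter only exists so the comparison logic can be exercised without the real packages installed.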

arneeichholtz commented 1 year ago

Yes, @DaShenZi721, I had the same problem with importing flax.deprecated while trying to run the byte-level text classification task.

Your suggestion ("from flax import nn") does not work for me, however. What worked was "from flax import linen as nn", together with installing jax==0.4.0, jaxlib==0.4.0 and flax==0.5.3. The older jax is needed because GlobalDeviceArray was replaced by jax.Array starting with jax==0.4.1, so only the older version is compatible with the GlobalDeviceArray usage in the code.
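The jax < 0.4.1 constraint described above can be checked at startup rather than discovered through an AttributeError later. A minimal sketch, assuming a plain numeric comparison of the release segment is good enough; release_tuple and predates_jax_array are hypothetical helpers, and packaging.version.Version is the more robust choice when pre-release or dev tags matter.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Numeric release part of a version string, e.g. "0.1.67+cuda111" -> (0, 1, 67)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def predates_jax_array(jax_version: str) -> bool:
    """True if this jax release is older than 0.4.1, where jax.Array replaced GlobalDeviceArray."""
    return release_tuple(jax_version) < (0, 4, 1)
```

For example, a training script could check predates_jax_array(jax.__version__) right after importing jax and abort with a clear message if it returns False.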

With that, the dataset loads successfully and the config settings are printed, but calling create_model() (from text_classification/train.py) raises the following error:

  File "lra_benchmarks/text_classification/train.py", line 66, in _create_model
    module = flax_module.partial(**model_kwargs)
AttributeError: type object 'LinearTransformerEncoder' has no attribute 'partial'

Do you know a way around this? How far have you gotten? Any help is greatly appreciated! Thanks, Arne

DaShenZi721 commented 1 year ago

Hello, @arneeichholtz! I have encountered this error before. It is also a Flax version issue: in Flax 0.4.0, the usage of Module.partial changed. You can modify your code according to this guide: https://flax.readthedocs.io/en/latest/advanced_topics/linen_upgrade_guide.html#module-partial-inside-other-modules
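Because Linen modules are Python dataclasses, the job Module.partial used to do (binding hyperparameters ahead of time) is handled either by constructing the module with keyword arguments or by functools.partial, the pattern the linked guide section describes for submodules. The sketch below uses a plain dataclass, ToyEncoder (hypothetical, with made-up hyperparameter values), as a stand-in so it runs without Flax. Porting LinearTransformerEncoder itself would additionally mean moving its hyperparameters from the old apply() signature into class fields.

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class ToyEncoder:
    """Hypothetical stand-in for a Linen module: hyperparameters are dataclass fields."""
    vocab_size: int
    emb_dim: int = 512
    num_layers: int = 6


# Old flax.nn style (pre-Linen), as in the traceback above:
#     module = flax_module.partial(**model_kwargs)
model_kwargs = {"vocab_size": 257, "num_layers": 4}  # illustrative values only

# Linen-style option 1: bind the hyperparameters now, construct later.
make_encoder = functools.partial(ToyEncoder, **model_kwargs)
encoder = make_encoder()

# Linen-style option 2: construct the module directly.
encoder_direct = ToyEncoder(**model_kwargs)
```

The partial object also lets call sites override a bound value, e.g. make_encoder(emb_dim=128).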

arneeichholtz commented 1 year ago

Thanks for the response! The partial call is resolved, but now I'm trying to port the nn.stochastic call, since it is also deprecated. The guide you linked explains how to do it, but I can't quite figure it out. Did you run into the same problem, or have you resolved it? Thanks, Arne