google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers
Apache License 2.0
710 stars, 77 forks

Error when running document retrieval #42

Open weixuansun opened 2 years ago

weixuansun commented 2 years ago

Hi, thanks for the great code. I am having some issues running the document retrieval task. I got the following error when running matching/train.py with the base transformer network:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 320, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 197, in main
    init_rng, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
TypeError: create_model() missing 1 required positional argument: 'input2_shape'

It seems that create_model originally took a single 'input_shape' and was later modified to take two? See this commit: https://github.com/google-research/long-range-arena/commit/093bfc64b8e5ec7813aad7be1b24b5e2b730a9bc

When I then pass two input shapes, the error above goes away, but I get a new one:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 321, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 198, in main
    init_rng, input_shape, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
  File "lra_benchmarks/matching/train.py", line 71, in create_model
    return _create_model(key)
  File "lra_benchmarks/matching/train.py", line 67, in _create_model
    (input2_shape, jnp.float32)])
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 536, in init_by_shape
    return jax_utils.partial_eval_by_shape(lazy_init, input_specs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 116, in partial_eval_by_shape
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 110, in <lambda>
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 533, in lazy_init
    return init_fn()
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 527, in init_fn
    return cls.init(_rng, *(inputs + args), name=name, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 489, in init
    y = instance.apply(*args, **kwargs)
TypeError: apply() got multiple values for argument 'vocab_size'

Could you let me know what a possible solution might be?

wuhaixu2016 commented 2 years ago

I came across the same problem. Does anybody have any suggestions?

wuhaixu2016 commented 2 years ago

Hi, I have solved this problem with the following changes:

(1) In './matching/train.py', pass the input shape twice: model = train_utils.get_model(model_type, create_model, model_kwargs, init_rng, input_shape, input_shape)

(2) In './utils/train_utils.py', change the model to 'Dual'
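As a rough illustration of why step (1) removes the first error, here is a minimal sketch using simplified stand-ins. The `create_model` signature, the placeholder key, and the shape values below are assumptions made for the example, not the repository's actual code. Only the `get_model(model_type, create_model, model_kwargs, ...)` call shape and the `input2_shape` parameter name come from the thread.

```python
# Simplified stand-ins; signatures here are assumptions, not the repository's code.

def create_model(flax_module, model_kwargs, key, input1_shape, input2_shape):
    """Hypothetical two-input model factory; it only records what it received."""
    return {"module": flax_module, "shapes": (input1_shape, input2_shape)}


def get_model(model_type, create_model_fn, model_kwargs, *create_model_args):
    """Forwards any extra positional arguments unchanged to the factory."""
    return create_model_fn(model_type, model_kwargs, *create_model_args)


input_shape = (4, 4000)  # illustrative (batch, length) values, not from the thread
init_rng = 0             # placeholder standing in for a JAX PRNG key

# Only one shape forwarded: the factory's second shape parameter stays unfilled.
try:
    get_model("transformer", create_model, {}, init_rng, input_shape)
    error_message = None
except TypeError as err:
    error_message = str(err)

# Both shapes forwarded, as in step (1): the call now succeeds.
model = get_model("transformer", create_model, {}, init_rng, input_shape, input_shape)

print(error_message)
print(model["shapes"])
```

Because `get_model` passes `*create_model_args` straight through, the number of shapes given at the call site has to match what the task's `create_model` expects.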

Chen-Chang commented 1 year ago

> Hi, I have solved this problem with the following operations: (1) './matching/train.py' model = train_utils.get_model(model_type, create_model, model_kwargs, init_rng, input_shape, input_shape) (2) './utils/train_utils.py' change the model as 'Dual'

Hi, could you please elaborate on how you debugged the problem? What does 'Dual' mean?
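One plausible reading, offered as an assumption rather than something the thread confirms: 'Dual' refers to a model class whose apply method accepts two inputs, as a matching task with two documents would need. The second traceback is consistent with a single-input model being initialized with two input specs, so that the second input lands in the `vocab_size` slot. The classes and the `init` helper below are hypothetical stand-ins written to reproduce that kind of error in plain Python. They are not the flax or repository APIs.

```python
# Hypothetical stand-ins; these are not the flax or repository classes.

class SingleInputEncoder:
    """One-input model: the second positional argument is vocab_size."""

    def apply(self, inputs, vocab_size=None):
        return ("single", inputs, vocab_size)


class DualInputEncoder:
    """Two-input model: both documents are positional, vocab_size comes after."""

    def apply(self, inputs1, inputs2, vocab_size=None):
        return ("dual", inputs1, inputs2, vocab_size)


def init(module, input_specs, **kwargs):
    """Passes one positional argument per input spec, plus keyword options."""
    return module.apply(*input_specs, **kwargs)


specs = ["doc1_spec", "doc2_spec"]  # two inputs, as in a matching task

# Single-input model: "doc2_spec" is bound to vocab_size positionally,
# and vocab_size is then supplied again as a keyword argument.
try:
    init(SingleInputEncoder(), specs, vocab_size=256)
    single_error = None
except TypeError as err:
    single_error = str(err)

# Dual-input model: both specs have their own parameter, so the call succeeds.
dual_result = init(DualInputEncoder(), specs, vocab_size=256)

print(single_error)
print(dual_result)
```

Under that reading, step (2) would mean having `get_model` in './utils/train_utils.py' hand the two-input variant of the model to `create_model` for the matching task.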