google-deepmind / alphageometry

Apache License 2.0

Has anyone managed to run it on Colab? #116

Open Azeks0 opened 6 months ago

Azeks0 commented 6 months ago

I am trying to run AlphaGeometry on Colab but keep running into problems:

https://colab.research.google.com/drive/1RrTfa3O80QOFL68rXdtAyn1KHt0NCCB3?usp=sharing

```
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/alphageometry/alphageometry.py", line 651, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/alphageometry/alphageometry.py", line 637, in main
    model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
  File "/content/alphageometry/alphageometry.py", line 203, in get_lm
    return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
  File "/content/alphageometry/lminference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/content/alphageometry/meliad_lib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/content/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/content/alphageometry/meliad_lib/meliad/transformer/decoder_stack.py", line 273, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
```

And this is the command where it stops:

```
!python -m alphageometry \
  --alsologtostderr \
  --problems_file=$(pwd)/examples.txt \
  --problem_name=orthocenter \
  --mode=alphageometry \
  --defs_file=$(pwd)/defs.txt \
  --rules_file=$(pwd)/rules.txt \
  --ckpt_path=$DATA \
  --vocab_path=$DATA/geometry.757.model \
  --gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \
  --gin_file=base_htrans.gin \
  --gin_file=size/medium_150M.gin \
  --gin_file=options/positions_t5.gin \
  --gin_file=options/lr_cosine_decay.gin \
  --gin_file=options/seq_1024_nocache.gin \
  --gin_file=geometry_150M_generate.gin \
  --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
  --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
  --gin_param=TransformerTaskConfig.sequence_length=128 \
  --gin_param=Trainer.restore_state_variables=False \
  --beam_size=$BEAM_SIZE \
  --search_depth=$DEPTH
```

Has anyone managed to make it work in Colab so far? I didn't find anything related online.
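For reference, the same invocation can be assembled from Python, which makes the shell variables the command depends on ($DATA, $MELIAD_PATH, $BATCH_SIZE, $BEAM_SIZE, $DEPTH) explicit instead of relying on how the notebook expands `$VAR` in `!` lines. This is a sketch, not the repository's tooling: `alphageometry_argv` and `alphageometry_env` are hypothetical helpers, and the values passed in the example call are placeholders I'm assuming, not values confirmed in this thread.

```python
import os
import shlex
import sys
from typing import Dict, List


def alphageometry_argv(cwd: str, data: str, meliad_path: str,
                       batch_size: int, beam_size: int, depth: int) -> List[str]:
    """Hypothetical helper: the flags from the command above as an argv list."""
    return [
        sys.executable, "-m", "alphageometry",
        "--alsologtostderr",
        f"--problems_file={cwd}/examples.txt",
        "--problem_name=orthocenter",
        "--mode=alphageometry",
        f"--defs_file={cwd}/defs.txt",
        f"--rules_file={cwd}/rules.txt",
        f"--ckpt_path={data}",
        f"--vocab_path={data}/geometry.757.model",
        f"--gin_search_paths={meliad_path}/transformer/configs,{cwd}",
        "--gin_file=base_htrans.gin",
        "--gin_file=size/medium_150M.gin",
        "--gin_file=options/positions_t5.gin",
        "--gin_file=options/lr_cosine_decay.gin",
        "--gin_file=options/seq_1024_nocache.gin",
        "--gin_file=geometry_150M_generate.gin",
        "--gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True",
        f"--gin_param=TransformerTaskConfig.batch_size={batch_size}",
        "--gin_param=TransformerTaskConfig.sequence_length=128",
        "--gin_param=Trainer.restore_state_variables=False",
        f"--beam_size={beam_size}",
        f"--search_depth={depth}",
    ]


def alphageometry_env(meliad_path: str) -> Dict[str, str]:
    """Hypothetical helper: copy of os.environ with meliad_path appended to PYTHONPATH."""
    env = dict(os.environ)
    existing = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = os.pathsep.join(p for p in (existing, meliad_path) if p)
    return env


# Placeholder values (assumptions, not taken from this thread).
example_argv = alphageometry_argv(os.getcwd(), "ag_ckpt_vocab", "meliad_lib/meliad", 2, 2, 2)
print(shlex.join(example_argv))
# To actually run it from the alphageometry checkout:
#   import subprocess
#   subprocess.run(example_argv, env=alphageometry_env("meliad_lib/meliad"), check=True)
```

Printing the joined command first is a cheap way to confirm every variable resolved to what you expect before the slow model load starts.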

I followed the tutorial provided on GitHub, but the error persists.
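The failing frame is `embeddings.astype(self.dtype)` in meliad's decoder_stack.py, so one quick check of a Colab environment is to run that kind of cast in isolation. This is a diagnostic sketch, not a fix: `probe_bfloat16_cast` is a hypothetical helper, and it tries both the string `"bfloat16"` and `jnp.bfloat16` because the thread doesn't show which form meliad passes. One plausible (unconfirmed) cause is that the name "bfloat16" resolves to a different bfloat16 type than the one the installed JAX recognizes; comparing this output against a working environment can help narrow that down.

```python
import importlib.util


def probe_bfloat16_cast() -> str:
    """Hypothetical helper: repeat the cast from decoder_stack.py in isolation."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    try:
        import jax
        import jax.numpy as jnp
    except Exception as err:  # e.g. an incompatible jax/jaxlib pair fails at import time
        return f"importing jax failed: {err!r}"

    x = jnp.zeros((2, 3), dtype=jnp.float32)
    lines = [f"jax {jax.__version__}"]
    # Try both forms, since this thread doesn't show which one meliad passes.
    for target in ("bfloat16", jnp.bfloat16):
        try:
            y = x.astype(target)
            lines.append(f"astype({target!r}): ok, dtype={y.dtype}")
        except TypeError as err:
            # The failure from the traceback would surface here.
            lines.append(f"astype({target!r}): TypeError: {err}")
    return "\n".join(lines)


print(probe_bfloat16_cast())
```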

Imrauviel commented 6 months ago

I don't run:

```
!python -m pip install flax==0.5.3
!python -m pip install jaxlib==0.4.20
```

and so on. Instead, I run:

```
!pip install --require-hashes -r requirements.txt
!pip install -r requirements.in
```
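The point of installing from requirements.txt is that every package ends up at its pinned version, so it can help to confirm that after the install finishes. A minimal sketch, assuming it runs from the repository root where requirements.txt lives: `parse_pins` and `mismatches` are hypothetical helpers, they only understand plain `name==version` lines (hash continuation lines and comments are skipped, extras are not handled), and versions are compared as exact strings.

```python
import re
from importlib import metadata
from pathlib import Path
from typing import Dict, Optional, Tuple

# "name==version" at the start of a line, e.g. "flax==0.5.3 \" in a hash-pinned file.
_PIN = re.compile(r"^([A-Za-z0-9][A-Za-z0-9._-]*)==([^\s;\\]+)")


def parse_pins(text: str) -> Dict[str, str]:
    """Collect exact pins; --hash continuation lines and comments don't match."""
    pins = {}
    for line in text.splitlines():
        match = _PIN.match(line.strip())
        if match:
            pins[match.group(1).lower()] = match.group(2)
    return pins


def mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map package -> (pinned, installed) where they differ; None means not installed."""
    result = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:  # exact string comparison, no version normalization
            result[name] = (wanted, installed)
    return result


requirements = Path("requirements.txt")
if requirements.exists():
    found = mismatches(parse_pins(requirements.read_text()))
    for name, (wanted, installed) in sorted(found.items()):
        print(f"{name}: pinned {wanted}, installed {installed}")
```

An empty report means the installed versions match the pins; any line for jax, jaxlib, or flax would be the first thing to look at given the error above.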