```
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/alphageometry/alphageometry.py", line 651, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/alphageometry/alphageometry.py", line 637, in main
    model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
  File "/content/alphageometry/alphageometry.py", line 203, in get_lm
    return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
  File "/content/alphageometry/lm_inference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/content/alphageometry/meliad_lib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/content/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/content/alphageometry/meliad_lib/meliad/transformer/decoder_stack.py", line 273, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
```
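The failing line is `embeddings.astype(self.dtype)` with a `bfloat16` dtype, which JAX normally accepts. One possible explanation (an assumption; the traceback alone does not prove it) is that the `bfloat16` type reaching JAX is not the one JAX recognizes, which can happen when the installed JAX-related packages differ from what the code was written against. A minimal, stdlib-only sketch for listing installed versions so they can be compared with those pinned in the repository's `requirements.txt`. The package list and the helper name `installed_versions` are hypothetical choices, not part of AlphaGeometry:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names to inspect; adjust to what the notebook actually installs.
PACKAGES = ("jax", "jaxlib", "flax", "ml-dtypes", "numpy", "tensorflow", "gin-config")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Print one line per package, e.g. to paste next to the pinned requirements.
    for name, ver in installed_versions().items():
        print(f"{name:12s} {ver or 'not installed'}")
```

Running it in the same Colab runtime that raised the error shows what that runtime actually has installed, which may differ from what the tutorial's install step was expected to produce.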
This is the cell where it stops:
```
!python -m alphageometry \
--alsologtostderr \
--problems_file=$(pwd)/examples.txt \
--problem_name=orthocenter \
--mode=alphageometry \
--defs_file=$(pwd)/defs.txt \
--rules_file=$(pwd)/rules.txt \
--ckpt_path=$DATA \
--vocab_path=$DATA/geometry.757.model \
--gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \
--gin_file=base_htrans.gin \
--gin_file=size/medium_150M.gin \
--gin_file=options/positions_t5.gin \
--gin_file=options/lr_cosine_decay.gin \
--gin_file=options/seq_1024_nocache.gin \
--gin_file=geometry_150M_generate.gin \
--gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
--gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
--gin_param=TransformerTaskConfig.sequence_length=128 \
--gin_param=Trainer.restore_state_variables=False \
--beam_size=$BEAM_SIZE \
--search_depth=$DEPTH
```
I am trying to run AlphaGeometry on Colab but keep running into problems. Here is my notebook:
https://colab.research.google.com/drive/1RrTfa3O80QOFL68rXdtAyn1KHt0NCCB3?usp=sharing

I followed the tutorial provided on GitHub, but the error persists. Has anyone managed to make it work in Colab so far? I couldn't find anything related online.