google-deepmind / alphageometry

Apache License 2.0

TypeError is raised when running alphageometry #108

Open · jiwei08 opened this issue 5 months ago

jiwei08 commented 5 months ago

DDAR runs successfully. However, when I try to run AlphaGeometry, I get the following TypeError:

TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype

I have attempted to identify the bug but have been unsuccessful so far. Could anyone provide assistance? Any help would be greatly appreciated!
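The error comes from JAX's own dtype check, not from AlphaGeometry code, so the exact package versions in the environment are worth recording when reporting or debugging it. The sketch below is a hypothetical helper (not part of alphageometry or meliad) that uses only the standard library's `importlib.metadata`. The package list is an assumption about which libraries are most likely to matter for `bfloat16` handling.

```python
# Hypothetical diagnostic helper (not part of alphageometry or meliad).
# Reports installed versions of packages that may disagree about bfloat16.
from importlib import metadata

# Assumed list of relevant packages; adjust as needed.
PACKAGES = ("jax", "jaxlib", "ml_dtypes", "numpy", "flax", "tensorflow")


def installed_versions(packages=PACKAGES):
    """Return {package_name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'not installed'}")
```

Running it as a script prints one line per package, with `not installed` for any that are missing.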

yunyin commented 4 months ago

I get the same error. Traceback:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/alphageometry/lm_inference_test.py", line 61, in setUpClass
    cls.loaded_lm = lm.LanguageModelInference(
  File "alphageometry/lm_inference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/alphageometry/meliad_lib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/alphageometry/meliad_lib/meliad/transformer/decoder_stack.py", line 273, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/4math/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
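The last frames show that the failure happens in meliad's `decoder_stack.py`, where `embeddings.astype(self.dtype)` reaches JAX's `dtypes.check_user_dtype_supported`. One way to narrow it down, sketched below under the assumption that JAX is importable, is to check whether `astype` accepts JAX's own `jnp.bfloat16` in the same environment. `astype_accepts` is a hypothetical helper, not a meliad function. If it returns `True` for `jnp.bfloat16` while the model still fails, one plausible (unconfirmed) explanation is that `self.dtype` holds a `bfloat16` type from another package or version that JAX does not recognize.

```python
# Hypothetical check (not part of meliad): does jnp's astype accept a dtype?
import jax.numpy as jnp


def astype_accepts(dtype) -> bool:
    """Return True if casting a small JAX array to `dtype` succeeds.

    Mirrors the failing call `embeddings.astype(self.dtype)`; JAX raises
    TypeError from its dtype check when it does not support `dtype`.
    """
    try:
        jnp.ones((2, 3), dtype=jnp.float32).astype(dtype)
    except TypeError:
        return False
    return True


print("jnp.bfloat16 accepted:", astype_accepts(jnp.bfloat16))
```

Passing the actual `self.dtype` value from the model config to `astype_accepts` (and printing its `type(...)` and module) would show whether it is the same `bfloat16` that JAX uses.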