jiwei08 opened 5 months ago
I have the same error.
```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/alphageometry/lm_inference_test.py", line 61, in setUpClass
    cls.loaded_lm = lm.LanguageModelInference(
  File "alphageometry/lm_inference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/alphageometry/meliad_lib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/alphageometry/meliad_lib/meliad/transformer/decoder_stack.py", line 273, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/4math/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
```
DDAR can be executed successfully. However, when attempting to test `alphageometry` (the language-model test in `lm_inference_test.py`), I encountered the `TypeError` shown above. I have tried to identify the bug but have been unsuccessful so far. Could anyone provide assistance? Any help would be greatly appreciated!
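One way to narrow this down is to repeat the failing cast outside the model. The script below is a hypothetical diagnostic sketch: the function names and the list of packages are my own choices, not part of AlphaGeometry or meliad. It prints the installed versions of the packages that plausibly matter, then runs `astype(jnp.bfloat16)` on a dummy float32 array, mirroring `embeddings.astype(self.dtype)` in `decoder_stack.py`. If this standalone cast also raises the `TypeError`, the installed package versions are the more likely suspect. If it succeeds, the `dtype` value that actually reaches `decoder_stack.py` is worth inspecting instead.

```python
"""Hypothetical diagnostic (not part of AlphaGeometry): can this environment cast to bfloat16?"""
import importlib.metadata as metadata

# Assumption: these are the packages most likely involved in dtype handling.
PACKAGES = ("jax", "jaxlib", "flax", "ml_dtypes", "numpy", "tensorflow")


def installed_versions(packages=PACKAGES):
    """Return {package_name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def bfloat16_cast_works():
    """Repeat the failing `astype(bfloat16)` call on a dummy array.

    Returns True if the cast succeeds, False if it raises TypeError,
    and None if JAX is not importable in this environment.
    """
    try:
        import jax.numpy as jnp
    except ImportError:
        return None  # JAX missing, so the check cannot run here.
    embeddings = jnp.zeros((1, 4), dtype=jnp.float32)  # stand-in for real embeddings
    try:
        return str(embeddings.astype(jnp.bfloat16).dtype) == "bfloat16"
    except TypeError as err:
        print("astype failed:", err)
        return False


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print("bfloat16 cast works:", bfloat16_cast_works())
```

Run it with the same Python interpreter that produced the traceback (the environment under `/4math` in the paths above), and include its output when reporting the versions.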