Hi, when using SMURF, I ran into a bug at the line "jax.value_and_grad(mod, has_aux=True, argnums=0)".
When running this line, I get the following error:
"python3.9/site-packages/jax/_src/api_util.py", line 583, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
KeyError: <class 'alphafold.common.protein.Protein'>
The 'alphafold.common.protein.Protein' class is quite simple, so what could be causing this?
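One possible cause, offered as an assumption rather than a confirmed diagnosis: `jax.value_and_grad` with `argnums=0` can only differentiate with respect to arrays or pytrees of arrays. `Protein` appears to be a plain frozen dataclass, and JAX does not know how to flatten an unregistered dataclass into arrays, so it fails when it tries to abstractify that argument. The minimal sketch below uses a hypothetical `ToyProtein` stand-in and a hypothetical `loss` function (not SMURF's actual `mod`) to show the failure and one workaround: registering the class with `jax.tree_util.register_pytree_node`. The exact exception type and message vary between JAX versions.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyProtein:  # hypothetical stand-in for alphafold.common.protein.Protein
    atom_positions: jnp.ndarray


def loss(p, scale):  # hypothetical loss; returns (value, aux) for has_aux=True
    return jnp.sum(p.atom_positions ** 2) * scale, None


# Differentiating with respect to an unregistered dataclass fails, because
# JAX treats it as an opaque leaf that is not a valid JAX type.
try:
    jax.value_and_grad(loss, has_aux=True, argnums=0)(ToyProtein(jnp.ones(3)), 2.0)
    failed = False
except (TypeError, KeyError):  # older JAX versions may raise KeyError
    failed = True

# Registering the class as a pytree lets JAX flatten it into its array fields.
jax.tree_util.register_pytree_node(
    ToyProtein,
    lambda p: ((p.atom_positions,), None),          # flatten: (children, aux_data)
    lambda _, children: ToyProtein(*children),      # unflatten: rebuild from children
)

(value, aux), grads = jax.value_and_grad(loss, has_aux=True, argnums=0)(
    ToyProtein(jnp.ones(3)), 2.0
)
print(value, grads.atom_positions)
```

An alternative that avoids registration is to pass only the arrays you need gradients for as argument 0 and keep the `Protein` object as a non-differentiated argument.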
I look forward to your reply.