System information
flax Version: 0.9.0
jax Version: 0.4.34
jaxlib Version: 0.4.34
on my laptop, running in CPU-only mode.
flax Version: 0.8.5
jax Version: 0.4.33
jaxlib Version: 0.4.33
on Google Colab.
Problem you have encountered:
The gradient is not computed for a loss function that uses jax.experimental.sparse operations.
What you expected to happen:
The gradient to be computed, giving the same result I get when the code uses dense matrices with no reference to sparsity.
Logs, error messages, etc:
Traceback (most recent call last):
File "D:\MLEnv\Lib\site-packages\jax\_src\core.py", line 824, in __getattr__
attr = getattr(self.aval, name)
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ConcreteArray' object has no attribute 'data'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "D:\MLEnv\Other stuff\issue.py", line 74, in <module>
grads = nnx.grad(Loss)(state.model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\graph.py", line 1158, in update_context_manager_wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\transforms\autodiff.py", line 164, in grad_wrapper
fn_out = gradded_fn(*pure_args)
^^^^^^^^^^^^^^^^^^^^^^
File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\transforms\autodiff.py", line 86, in __call__
out = self.f(*args)
^^^^^^^^^^^^^
File "D:\MLEnv\Other stuff\issue.py", line 73, in <lambda>
Loss=lambda model: compute_E(H, model)
^^^^^^^^^^^^^^^^^^^
File "D:\MLEnv\Other stuff\issue.py", line 65, in compute_E
N=jax.experimental.sparse.bcoo_reduce_sum(N,axes=(0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\MLEnv\Lib\site-packages\jax\experimental\sparse\bcoo.py", line 2125, in bcoo_reduce_sum
mat.data, mat.indices, spinfo=mat._info, axes=axes)
^^^^^^^^
AttributeError: JVPTracer has no attribute data
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Steps to reproduce:
colab
The Colab link also contains the dense-matrix version of the code, which works.
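For reference, here is a minimal self-contained sketch of the sparse reduction that appears in the traceback. The matrix and names here are assumptions for illustration, not the reporter's actual code; it shows that bcoo_reduce_sum behaves as expected in a plain forward pass, while the failure above occurs only once a function containing such a call is differentiated with nnx.grad.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical small matrix standing in for the reporter's operand.
H = sparse.BCOO.fromdense(jnp.array([[1.0, 0.0],
                                     [0.0, 2.0]]))

# In a plain forward pass, the sparse reduction works fine.
# Note that `axes` should be a tuple: (0,), not (0), which is just
# the integer 0.
col_sums = sparse.bcoo_reduce_sum(H, axes=(0,)).todense()
```

The same call is what fails inside the differentiated loss in the traceback, where `mat` arrives as a JVPTracer instead of a BCOO object and therefore has no `.data` attribute.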