google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.15k stars 648 forks source link

flax does not work with sparse matrices #4366

Closed Prashant-Shekhar-Rao closed 1 week ago

Prashant-Shekhar-Rao commented 2 weeks ago

System information

flax Version: 0.9.0 jax Version: 0.4.34 jaxlib Version: 0.4.34 on my laptop, running in CPU only mode.

flax Version: 0.8.5 jax Version: 0.4.33 jaxlib Version: 0.4.33 on google collab

Problem you have encountered:

Gradient was not calculated with the given loss function

What you expected to happen:

Gradient to be calculated, getting the same result I get when the code makes no reference to sparsity.

Logs, error messages, etc:

Traceback (most recent call last):
  File "D:\MLEnv\Lib\site-packages\jax\_src\core.py", line 824, in __getattr__
    attr = getattr(self.aval, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ConcreteArray' object has no attribute 'data'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:\MLEnv\Other stuff\issue.py", line 74, in <module>
    grads = nnx.grad(Loss)(state.model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\graph.py", line 1158, in update_context_manager_wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\transforms\autodiff.py", line 164, in grad_wrapper
    fn_out = gradded_fn(*pure_args)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "D:\MLEnv\Lib\site-packages\flax\nnx\nnx\transforms\autodiff.py", line 86, in __call__
    out = self.f(*args)
          ^^^^^^^^^^^^^
  File "D:\MLEnv\Other stuff\issue.py", line 73, in <lambda>
    Loss=lambda model: compute_E(H, model)
                       ^^^^^^^^^^^^^^^^^^^
  File "D:\MLEnv\Other stuff\issue.py", line 65, in compute_E
    N=jax.experimental.sparse.bcoo_reduce_sum(N,axes=(0))
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\MLEnv\Lib\site-packages\jax\experimental\sparse\bcoo.py", line 2125, in bcoo_reduce_sum
    mat.data, mat.indices, spinfo=mat._info, axes=axes)
    ^^^^^^^^
AttributeError: JVPTracer has no attribute data
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Steps to reproduce:

colab

import jax
import jax.numpy as jnp
from flax import nnx
import optax
from jax.experimental import sparse
from jax import vmap
class Model(nnx.Module):
    def __init__(self,rngs: nnx.Rngs):
        self.w = nnx.Param(jax.random.uniform(key=rngs.params(), shape=(3, 3),minval=0.0, maxval=0.1))
        self.b = nnx.Param(jnp.zeros((3)))
    def __call__(self, x):
        x= self.b+x @ self.w
        x = jax.nn.tanh(x)
        x=jnp.sum(x)
        return x

tx = optax.adam(0.01)
I=jnp.array([[1,0],[0,1]])
sig=jnp.array([[1,0],[0,-1]])
flip=jnp.array([[0,1],[1,0]])
h=0.5
H = jnp.kron(jnp.kron(sig,I),sig)+jnp.kron(jnp.kron(sig,sig),I)+jnp.kron(jnp.kron(I,sig),sig)-h*(jnp.kron(jnp.kron(I,I),flip)+jnp.kron(jnp.kron(I,flip),I)+jnp.kron(jnp.kron(flip,I),I))
H=sparse.BCOO.fromdense(H)
#remove the above line
model = Model(nnx.Rngs(0))

array_8x8 = jnp.array([[(i, j) for j in range(8)] for i in range(8)])
array_8 = jnp.arange(8)
#@sparse.sparsify
#remove the above line
def f(x,model):
    i,j=x
    P_i = model(jnp.array([-2 * ((i // 4)%2) + 1, -2 * ((i // 2)%2) + 1, -2 * ((i // 1)%2) + 1]))
    P_i=jnp.exp(P_i)
    P_j = model(jnp.array([-2 * ((j // 4)%2) + 1, -2 * ((j // 2)%2) + 1, -2 * ((j // 1)%2) + 1]))
    P_j=jnp.exp(P_j)
    P=jnp.multiply(P_i, P_j)
    E = H[i,j] * P
    return E
#@sparse.sparsify
#remove the above line
def normalize(i,model):
    P=model(jnp.array([-2 * ((i // 4)%2) + 1, -2 * ((i // 2)%2) + 1, -2 * ((i // 1)%2) + 1]))
    P=jnp.exp(P)
    N=jnp.multiply(P, P)
    return (N)

#@sparse.sparsify
#remove the above line
def compute_E(H, model):

    f_mapped = vmap(vmap(lambda x: f(x, model)))
    #f_sp=jax.experimental.sparse.sparsify(f_mapped)
#remove the above line
    N_mapped = vmap(lambda x: normalize(x, model))

#remove the above line
    E=f_mapped(array_8x8)
    E=jax.experimental.sparse.bcoo_reduce_sum(E,axes=(0,1))
    #replace by
    #E=jnp.sum(E)
    N=N_mapped(array_8)
    #N_sp=jax.experimental.sparse.sparsify(N_mapped)
    N=jax.experimental.sparse.bcoo_reduce_sum(N,axes=(0))
    #N=jnp.sum(N)
    return (E/N)

#model=sparse.sparsify(model)
#remove the above line
state = nnx.Optimizer(model, tx)
Loss=lambda model: compute_E(H, model)
grads = nnx.grad(Loss)(state.model)
state.update(grads)

The collab link also has the code with dense matrices which works.

#This is the code without any reference to sparsity. This works
import jax
import jax.numpy as jnp
from flax import nnx
import optax
from jax.experimental import sparse
from jax import vmap
class Model(nnx.Module):
    def __init__(self,rngs: nnx.Rngs):
        self.w = nnx.Param(jax.random.uniform(key=rngs.params(), shape=(3, 3),minval=0.0, maxval=0.1))
        self.b = nnx.Param(jnp.zeros((3)))
    def __call__(self, x):
        x= self.b+x @ self.w
        x = jax.nn.tanh(x)
        x=jnp.sum(x)
        return x

tx = optax.adam(0.01)
I=jnp.array([[1,0],[0,1]])
sig=jnp.array([[1,0],[0,-1]])
flip=jnp.array([[0,1],[1,0]])
h=0.5
H = jnp.kron(jnp.kron(sig,I),sig)+jnp.kron(jnp.kron(sig,sig),I)+jnp.kron(jnp.kron(I,sig),sig)-h*(jnp.kron(jnp.kron(I,I),flip)+jnp.kron(jnp.kron(I,flip),I)+jnp.kron(jnp.kron(flip,I),I))
model = Model(nnx.Rngs(0))

array_8x8 = jnp.array([[(i, j) for j in range(8)] for i in range(8)])
array_8 = jnp.arange(8)
def f(x,model):
    i,j=x
    P_i = model(jnp.array([-2 * ((i // 4)%2) + 1, -2 * ((i // 2)%2) + 1, -2 * ((i // 1)%2) + 1]))
    P_i=jnp.exp(P_i)
    P_j = model(jnp.array([-2 * ((j // 4)%2) + 1, -2 * ((j // 2)%2) + 1, -2 * ((j // 1)%2) + 1]))
    P_j=jnp.exp(P_j)
    P=jnp.multiply(P_i, P_j)
    E = H[i,j] * P
    return E
def normalize(i,model):
    P=model(jnp.array([-2 * ((i // 4)%2) + 1, -2 * ((i // 2)%2) + 1, -2 * ((i // 1)%2) + 1]))
    P=jnp.exp(P)
    N=jnp.multiply(P, P)
    return (N)

def compute_E(H, model):

    f_mapped = vmap(vmap(lambda x: f(x, model)))
    N_mapped = vmap(lambda x: normalize(x, model))
    E=f_mapped(array_8x8)
    E=jnp.sum(E)
    N=N_mapped(array_8)
    N=jnp.sum(N)
    return (E/N)

state = nnx.Optimizer(model, tx)
Loss=lambda model: compute_E(H, model)
grads = nnx.grad(Loss)(state.model)
state.update(grads)
print(Loss(model))
Prashant-Shekhar-Rao commented 1 week ago

My bad, I was calling jax.experimental.sparse.bcoo_reduce_sum on N which was not a sparse matrix. This is not a issue in jax or flax.