jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.01k stars 2.75k forks source link

Differentiating BCOO sparse matrix constructor produces a dense matrix #14623

Closed aterenin closed 10 months ago

aterenin commented 1 year ago

Description

Unless I am mistaken, differentiating the BCOO constructor appears to produce dense matrices. Here's an MWE:

import jax 
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 1000
m = 777
p = 11
nonzeroes = jnp.stack((jnp.arange(0,m), jnp.arange(p,m+p)), axis=-1)

def quadratic(p):
    """Return ``1^T A 1`` for a sparse matrix ``A`` whose entries scale as ``1 / p``.

    The stored entries of ``A`` sit at the (row, column) pairs listed in the
    module-level ``nonzeroes`` array, and each one equals
    ``(row - column)**2 / p``.  ``A`` has shape ``(n, n)``.
    """
    rows = nonzeroes[..., 0]
    cols = nonzeroes[..., 1]
    entries = (rows - cols) ** 2 / p
    sparse = BCOO((entries, nonzeroes), shape=(n, n), unique_indices=True)
    # Same association order as a chained ``a @ b @ c``: (ones^T @ A) @ ones.
    left_product = jnp.ones(n).T @ sparse
    return left_product @ jnp.ones(n)

graph = jax.xla_computation(jax.grad(quadratic))(1.0)

with open("t.dot", "w") as f:
    f.write(graph.as_hlo_dot_graph())

I've attached a computational graph: t.pdf. The problem is dot.28 (dark blue) followed by gather.57, which appears to assemble a dense 1000x1000 matrix and then gather only the entries at the sparse matrix's nonzero indices.

What jax/jaxlib version are you using?

0.4.4

Which accelerator(s) are you using?

N/A

Additional system info

N/A

NVIDIA GPU info

N/A

jakevdp commented 1 year ago

I don't think any dense matrices are produced on the JAX side of things. For example, here is the jaxpr of your differentiated function, which does not show any 1000,1000 variable being produced:

print(jax.make_jaxpr(jax.grad(quadratic))(1.0))
{ lambda a:i32[777,2]; b:f32[]. let
    c:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:i32[777] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(777, 1)
      unique_indices=True
    ] a c
    e:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
    f:i32[777] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(777, 1)
      unique_indices=True
    ] a e
    g:i32[777] = sub d f
    h:i32[777] = integer_pow[y=2] g
    i:f32[777] = convert_element_type[new_dtype=float32 weak_type=True] h
    j:f32[777] = div i b
    k:f32[] = integer_pow[y=-2] b
    l:f32[1000] = broadcast_in_dim[broadcast_dimensions=() shape=(1000,)] 1.0
    m:f32[777] = convert_element_type[new_dtype=float32 weak_type=False] j
    n:f32[1000] = bcoo_dot_general[
      dimension_numbers=(((0,), (0,)), ((), ()))
      lhs_spinfo=BCOOInfo(shape=(1000, 1000), indices_sorted=False, unique_indices=True)
    ] m a l
    o:f32[1000] = broadcast_in_dim[broadcast_dimensions=() shape=(1000,)] 1.0
    _:f32[] = dot_general[
      dimension_numbers=(((0,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] n o
    p:f32[1000] = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] 1.0 o
    q:f32[777] = broadcast_in_dim[broadcast_dimensions=() shape=(777,)] 1.0
    _:f32[777] r:i32[777,2] = bcoo_transpose[
      permutation=(1, 0)
      spinfo=BCOOInfo(shape=(1, 1), indices_sorted=False, unique_indices=False)
    ] q a
    s:f32[777] = bcoo_dot_general_sampled[
      dimension_numbers=(((), ()), ((), ()))
    ] p l r
    t:f32[777] _:i32[777,2] = bcoo_transpose[
      permutation=(1, 0)
      spinfo=BCOOInfo(shape=(1, 1), indices_sorted=False, unique_indices=False)
    ] s r
    u:f32[777] = convert_element_type[new_dtype=float32 weak_type=True] t
    v:f32[777] = mul u k
    w:f32[777] = mul v i
    x:f32[] = reduce_sum[axes=(0,)] w
    y:f32[] = neg x
  in (y,) }
jakevdp commented 1 year ago

That said, what might be happening is that you're hitting this TODO: https://github.com/google/jax/blob/7e001d842e1cbfdbf991a4bd9b236012cc40fba4/jax/experimental/sparse/bcoo.py#L1091-L1092

dot_general_sampled is used in the transpose rule of dot_general, which requires extracting the entries at the sparse indices from a dense dot product. There are probably more efficient ways to implement this, hence the TODO.

jakevdp commented 1 year ago

One way around this would be to use forward-mode autodiff, which does not use the transpose rule.

aterenin commented 1 year ago

Thanks for your quick response! I'm not familiar enough with the internals to know whether this is happening in JAX or somewhere lower-level. However, here's an attempted workaround, via a custom reverse rule.

from functools import partial
from jax import custom_vjp

@partial(custom_vjp, nondiff_argnums=(3,))
def sparse_matrix_product(target, values, nonzeroes, shape):
    """Multiply the sparse matrix given by ``(values, nonzeroes, shape)`` by ``target``.

    A hand-written reverse rule is attached below so that differentiating
    through the product only ever touches the stored entries.

    Args:
        target: dense right-hand operand (the reverse rule indexes it as
            ``target[..., col, :]``, so it needs at least two dimensions).
        values: stored entries of the sparse matrix.
        nonzeroes: integer (row, column) index pairs, one per entry of
            ``values``; they are promised to be unique.
        shape: shape of the sparse matrix.  This is the only
            non-differentiable (static) argument.

    Returns:
        ``BCOO((values, nonzeroes), shape=shape, unique_indices=True) @ target``.
    """
    return BCOO((values, nonzeroes), shape=shape, unique_indices=True) @ target


def sparse_matrix_product_fwd(target, values, nonzeroes, shape):
    """Forward rule: return the product plus the residuals for the reverse rule."""
    out = sparse_matrix_product(target, values, nonzeroes, shape)
    # ``nonzeroes`` is carried as a residual rather than being listed in
    # ``nondiff_argnums``: JAX rejects traced arrays in nondiff positions, so
    # this keeps the function usable when the indices are traced (e.g. in jit).
    carry = (target, values, nonzeroes)
    return out, carry


def sparse_matrix_product_rev(shape, carry, cotangents):
    """Reverse rule: cotangents for ``(target, values, nonzeroes)``.

    Args:
        shape: the static sparse-matrix shape (the nondiff argument).
        carry: ``(target, values, nonzeroes)`` saved by the forward rule.
        cotangents: cotangent of the product output.

    Returns:
        A 3-tuple: the cotangent for ``target`` (transposed sparse matrix
        times ``cotangents``), the cotangent for ``values`` (one number per
        stored entry), and ``None`` for the integer index array.
    """
    (target, values, nonzeroes) = carry
    target_cotangents = BCOO((values, nonzeroes), shape=shape, unique_indices=True).T @ cotangents

    rows = nonzeroes[..., 0]
    cols = nonzeroes[..., 1]

    # d(out[row, :]) / d(values[k]) = target[col, :], so only the rows and
    # columns named by the stored indices are gathered; nothing dense of size
    # ``shape`` is formed here.
    sparse_cotangents = cotangents[..., rows, :]
    sparse_targets = target[..., cols, :]
    values_cotangents = jnp.sum(sparse_cotangents * sparse_targets, axis=-1)

    # Integer indices carry no gradient; ``None`` stands for a zero cotangent.
    return (target_cotangents, values_cotangents, None)


sparse_matrix_product.defvjp(sparse_matrix_product_fwd, sparse_matrix_product_rev)

import jax 
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 1000
m = 777
p = 11
nonzeroes = jnp.stack((jnp.arange(0,m), jnp.arange(p,m+p)), axis=-1)

def quadratic(p):
    """Return ``1^T A 1`` for a sparse matrix ``A`` whose entries scale as ``1 / p``.

    The stored entries of ``A`` sit at the (row, column) pairs listed in the
    module-level ``nonzeroes`` array, and each one equals
    ``(row - column)**2 / p``.  ``A`` has shape ``(n, n)``.
    """
    rows = nonzeroes[..., 0]
    cols = nonzeroes[..., 1]
    entries = (rows - cols) ** 2 / p
    sparse = BCOO((entries, nonzeroes), shape=(n, n), unique_indices=True)
    # Same association order as a chained ``a @ b @ c``: (ones^T @ A) @ ones.
    left_product = jnp.ones(n).T @ sparse
    return left_product @ jnp.ones(n)

def fixed_quadratic(p):
    """Evaluate ``1^T A 1`` with the product routed through ``sparse_matrix_product``.

    The entries are built exactly as in the plain MWE
    (``(row - column)**2 / p`` at the positions in ``nonzeroes``), but the
    matrix-times-ones product goes through the custom-VJP helper instead of
    ``BCOO.__matmul__``.
    """
    rows = nonzeroes[..., 0]
    cols = nonzeroes[..., 1]
    entries = (rows - cols) ** 2 / p
    left = jnp.ones(n).T
    # The helper's reverse rule indexes a trailing column axis, so the ones
    # vector is passed as an (n, 1) column.
    ones_column = jnp.expand_dims(jnp.ones(n), -1)
    product = sparse_matrix_product(ones_column, entries, nonzeroes, (n, n))
    return (left @ product).squeeze()

(jax.grad(quadratic)(1.0)), jax.grad(fixed_quadratic)(1.0)

graph = jax.xla_computation(jax.grad(fixed_quadratic))(1.0)

with open("t.dot", "w") as f:
    f.write(graph.as_hlo_dot_graph())

Here's the corresponding computational graph: t.pdf. This appears so far to avoid dense matrix assembly, though I have not done any testing yet beyond an interactive notebook and thus there may be mistakes.

jakevdp commented 1 year ago

Very nice! Unfortunately it's not that easy if you have to handle all the possibilities offered by dot_general, which the BCOO code has to account for. We could probably special-case some situations and do something like this in simpler cases.

jakevdp commented 1 year ago

#14630 should fix the specific case you bring up here, but more general cases are a bit harder to get right...

aterenin commented 1 year ago

Thanks so much for the extremely quick response - I definitely did not expect someone to even see the issue the same day I posted it, much less fix it!

I tried it out, and it appears to fix the issue for the original MWE with matrix-vector products, but not for matrix-matrix products, which is actually what occurs in my original codebase where I discovered this issue before reducing this to the MWE. For example, replacing the original MWE with

def quadratic(p):
    """Matrix-matrix (SpMM) variant of the MWE: ``1^T (A @ ones((n, p)))``.

    NOTE(review): inside this function ``p`` is the argument being
    differentiated, which shadows the module-level ``p = 11``.  It is also
    used below as the column count of ``jnp.ones((n, p))``, where an integer
    is required -- presumably the module-level integer was intended there;
    confirm against the code that was actually run.
    """
    x1 = nonzeroes[...,0]
    x2 = nonzeroes[...,1]
    # Each stored entry is (row - column)**2 scaled by 1 / p.
    values = (x1 - x2)**2 / p
    matrix = BCOO((values, nonzeroes), shape=(n,n), unique_indices=True)
    # The parentheses force the sparse @ dense-matrix product to happen first.
    return jnp.ones(n).T @ (matrix @ jnp.ones((n,p)))

appears to hit the slow version. My custom_vjp above also works for this case.

jakevdp commented 1 year ago

Reopening - I'm taking a look at the SpMM issue

aterenin commented 1 year ago

Meanwhile: if anyone needs this before a fix lands in the next version of JAX, here's my updated workaround for this and #14642. You can replace BCOO with bug_workaround_BCOO below, as long as you provide it with arguments for which indices_sorted and unique_indices are both true:

jax.config.update("jax_bcoo_cusparse_lowering", True)

class bug_workaround_BCOO():
    """Minimal stand-in for ``BCOO`` that supports only ``self @ dense``.

    It stores the values, indices and shape, and forwards matrix products to
    ``sparse_matrix_product``.
    """

    def __init__(self, args: Tuple[Array, Array], *, shape: Sequence[int], indices_sorted: bool, unique_indices: bool) -> None:
        """Store ``(values, nonzeroes)`` and ``shape``.

        Both ``indices_sorted`` and ``unique_indices`` must be true; the
        flags are only checked here, not stored.
        """
        # The product helper always declares the indices sorted and unique,
        # so anything else is refused up front.
        assert indices_sorted
        assert unique_indices
        (self.values, self.nonzeroes) = args
        # NOTE(review): shape is later passed as a jit static argument, so it
        # presumably has to be hashable (e.g. a tuple) -- confirm.
        self.shape = shape

    def __matmul__(
        self, other: Float[Array, "M P"]
    ) -> Float[Array, "N P"]:
        """Return ``self @ other`` via ``sparse_matrix_product``."""
        return sparse_matrix_product(other, self.values, self.nonzeroes, self.shape)

@partial(custom_vjp, nondiff_argnums=(3,))
@partial(jit, static_argnums=3)
def sparse_matrix_product(target: Float[Array, "M P"], values: Float[Array, "Z"], nonzeroes: Float[Array, "Z 2"], shape: Sequence[int]):
    """Sparse-times-dense product via ``_bcoo_dot_general``.

    Contracts the last dimension of the sparse operand with the
    second-to-last dimension of ``target`` (dimension 0 when ``target`` has
    fewer than two dimensions); all leading dimensions on both sides are
    treated as batch dimensions.

    Args:
        target: dense right-hand operand.
        values: stored entries of the sparse operand.
        nonzeroes: index array of the sparse operand.
        shape: shape of the sparse operand (static).
    """
    sparse_rank = len(shape)
    dense_rank = target.ndim
    contracting = ((sparse_rank - 1,), (max(dense_rank - 2, 0),))
    batch = (tuple(range(sparse_rank - 2)), tuple(range(dense_rank - 2)))
    return _bcoo_dot_general(
        values,
        nonzeroes,
        target,
        dimension_numbers=(contracting, batch),
        # Presumably (shape, indices_sorted, unique_indices) -- both flags on.
        lhs_spinfo=SparseInfo(shape, True, True),
    )

def sparse_matrix_product_fwd(target: Float[Array, "M P"], values: Float[Array, "Z"], nonzeroes: Float[Array, "Z 2"], shape: Sequence[int]):
    """Forward rule: the product together with the residuals for the reverse rule.

    Returns:
        ``(out, (target, values, nonzeroes))``.
    """
    residuals = (target, values, nonzeroes)
    return sparse_matrix_product(target, values, nonzeroes, shape), residuals

@partial(jit, static_argnums=0)
def sparse_matrix_product_rev(shape: Sequence[int], carry: Tuple[Float[Array, "M P"], Float[Array, "Z"], Float[Array, "Z 2"]], cotangents: Float[Array, "N P"]):
    """Reverse rule: cotangents for ``(target, values, nonzeroes)``.

    Args:
        shape: shape of the sparse operand (static).
        carry: ``(target, values, nonzeroes)`` saved by the forward rule.
        cotangents: cotangent of the product output.

    Returns:
        ``(target_cotangents, values_cotangents, None)``; the trailing
        ``None`` is the entry for the index array.
    """
    (target, values, nonzeroes) = carry
    # Contract the sparse operand's second-to-last (row) dimension with the
    # last dimension of the transposed cotangents.
    sparse_dim = len(shape) - 2
    dense_dim = cotangents.ndim - 1 if cotangents.ndim >= 2 else 0
    sparse_batch_dim = tuple(i for i in range(len(shape) - 2))
    dense_batch_dim = tuple(i for i in range(target.ndim - 2))
    dimension_numbers = (((dense_dim,),(sparse_dim,)),(dense_batch_dim,sparse_batch_dim))
    info = SparseInfo(shape, True, True)
    # NOTE(review): `.T` reverses every axis, so inputs with batch dimensions
    # (ndim > 2) may not line up with the batch bookkeeping above -- confirm.
    target_cotangents = _bcoo_rdot_general(cotangents.T, values, nonzeroes, dimension_numbers = dimension_numbers, rhs_spinfo = info).T

    # Row / column index of every stored entry.
    x1 = nonzeroes[...,0]
    x2 = nonzeroes[...,1]

    # Gather only the rows of the cotangent and of the target that the stored
    # entries touch, then reduce over the trailing axis: one value per entry.
    sparse_cotangents = cotangents[...,x1,:]
    sparse_targets = target[...,x2,:]
    values_cotangents = (sparse_cotangents * sparse_targets).sum(axis=-1)

    return (target_cotangents, values_cotangents, None)

sparse_matrix_product.defvjp(sparse_matrix_product_fwd, sparse_matrix_product_rev)
jakevdp commented 10 months ago

I think the issue reported here is fixed by #14630