I don't think any dense matrices are produced on the JAX side of things. For example, here is the jaxpr of your differentiated function, which does not show any [1000,1000]-shaped variable being produced:
print(jax.make_jaxpr(jax.grad(quadratic))(1.0))
{ lambda a:i32[777,2]; b:f32[]. let
c:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
d:i32[777] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(777, 1)
unique_indices=True
] a c
e:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
f:i32[777] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1,), start_index_map=(1,))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(777, 1)
unique_indices=True
] a e
g:i32[777] = sub d f
h:i32[777] = integer_pow[y=2] g
i:f32[777] = convert_element_type[new_dtype=float32 weak_type=True] h
j:f32[777] = div i b
k:f32[] = integer_pow[y=-2] b
l:f32[1000] = broadcast_in_dim[broadcast_dimensions=() shape=(1000,)] 1.0
m:f32[777] = convert_element_type[new_dtype=float32 weak_type=False] j
n:f32[1000] = bcoo_dot_general[
dimension_numbers=(((0,), (0,)), ((), ()))
lhs_spinfo=BCOOInfo(shape=(1000, 1000), indices_sorted=False, unique_indices=True)
] m a l
o:f32[1000] = broadcast_in_dim[broadcast_dimensions=() shape=(1000,)] 1.0
_:f32[] = dot_general[
dimension_numbers=(((0,), (0,)), ((), ()))
precision=None
preferred_element_type=None
] n o
p:f32[1000] = dot_general[
dimension_numbers=(((), ()), ((), ()))
precision=None
preferred_element_type=None
] 1.0 o
q:f32[777] = broadcast_in_dim[broadcast_dimensions=() shape=(777,)] 1.0
_:f32[777] r:i32[777,2] = bcoo_transpose[
permutation=(1, 0)
spinfo=BCOOInfo(shape=(1, 1), indices_sorted=False, unique_indices=False)
] q a
s:f32[777] = bcoo_dot_general_sampled[
dimension_numbers=(((), ()), ((), ()))
] p l r
t:f32[777] _:i32[777,2] = bcoo_transpose[
permutation=(1, 0)
spinfo=BCOOInfo(shape=(1, 1), indices_sorted=False, unique_indices=False)
] s r
u:f32[777] = convert_element_type[new_dtype=float32 weak_type=True] t
v:f32[777] = mul u k
w:f32[777] = mul v i
x:f32[] = reduce_sum[axes=(0,)] w
y:f32[] = neg x
in (y,) }
That said, what might be happening is that you're hitting this TODO: https://github.com/google/jax/blob/7e001d842e1cbfdbf991a4bd9b236012cc40fba4/jax/experimental/sparse/bcoo.py#L1091-L1092. dot_general_sampled is used in the transpose rule of dot_general, which requires extracting sparse indices from a dense dot product. There are probably more efficient ways to implement this, hence the TODO.
One way around this would be to use forward-mode autodiff, which does not use the transpose rule.
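For example, here is a minimal sketch (using the quadratic function from your MWE) that differentiates with jax.jacfwd instead of jax.grad; for a scalar input this computes the same derivative, but via JVPs, so the transpose rule is never invoked:

df = jax.jacfwd(quadratic)(1.0)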
Thanks for your quick response! I'm not familiar enough with the internals to know whether this is happening in JAX or somewhere lower-level. However, here's an attempted workaround, via a custom reverse rule.
from functools import partial
from jax import custom_vjp

@partial(custom_vjp, nondiff_argnums=(2, 3))
def sparse_matrix_product(target, values, nonzeroes, shape):
    return BCOO((values, nonzeroes), shape=shape, unique_indices=True) @ target

def sparse_matrix_product_fwd(target, values, nonzeroes, shape):
    out = sparse_matrix_product(target, values, nonzeroes, shape)
    carry = (target, values)
    return out, carry

def sparse_matrix_product_rev(nonzeroes, shape, carry, cotangents):
    (target, values) = carry
    # Cotangent of the dense operand: multiply by the transposed sparse matrix.
    target_cotangents = BCOO((values, nonzeroes), shape=shape, unique_indices=True).T @ cotangents
    # Cotangent of the values: evaluate the dense cotangent only at the stored indices.
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    sparse_cotangents = cotangents[..., x1, :]
    sparse_targets = target[..., x2, :]
    values_cotangents = jnp.sum(sparse_cotangents * sparse_targets, axis=-1)
    return (target_cotangents, values_cotangents)

sparse_matrix_product.defvjp(sparse_matrix_product_fwd, sparse_matrix_product_rev)
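For reference, the identities the reverse rule above is meant to implement, for $Y = AX$ with $A$ sparse:

$$\bar{X} = A^\top \bar{Y}, \qquad \bar{A}_{ij} = \sum_k \bar{Y}_{ik} X_{jk},$$

where $\bar{A}_{ij}$ only needs to be evaluated at the stored index pairs $(i, j)$, which is what the indexing by x1 and x2 computes.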
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 1000
m = 777
p = 11
nonzeroes = jnp.stack((jnp.arange(0, m), jnp.arange(p, m + p)), axis=-1)

def quadratic(theta):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2)**2 / theta
    matrix = BCOO((values, nonzeroes), shape=(n, n), unique_indices=True)
    return jnp.ones(n).T @ matrix @ jnp.ones(n)

def fixed_quadratic(theta):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2)**2 / theta
    return (jnp.ones(n).T @ sparse_matrix_product(jnp.expand_dims(jnp.ones(n), -1), values, nonzeroes, (n, n))).squeeze()

jax.grad(quadratic)(1.0), jax.grad(fixed_quadratic)(1.0)

graph = jax.xla_computation(jax.grad(fixed_quadratic))(1.0)
with open("t.dot", "w") as f:
    f.write(graph.as_hlo_dot_graph())
Here's the corresponding computational graph: t.pdf. This appears so far to avoid dense matrix assembly, though I have not done any testing yet beyond an interactive notebook and thus there may be mistakes.
Very nice! Unfortunately it's not that easy if you have to handle all the possibilities offered by dot_general, which the BCOO code has to account for. We could probably special-case some situations and do something like this in simpler cases.
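As an illustration of that generality (toy arrays made up for this example): dot_general supports arbitrary contracting and batch dimensions, e.g. a batched matmul, whereas the workaround above hard-codes a single contraction pattern.

from jax import lax
import jax.numpy as jnp

x = jnp.ones((8, 3, 4))
y = jnp.ones((8, 4, 5))
# contract x's dim 2 with y's dim 1, batching over dim 0 of both: result shape (8, 3, 5)
z = lax.dot_general(x, y, dimension_numbers=(((2,), (1,)), ((0,), (0,))))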
Thanks so much for the extremely quick response - I definitely did not expect someone to even see the issue the same day I posted it, much less fix it!
I tried it out, and it appears to fix the issue for the original MWE with matrix-vector products, but not for matrix-matrix products, which is actually what occurs in my original codebase, where I discovered this issue before reducing it to the MWE. For example, replacing the original MWE with
def quadratic(theta):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2)**2 / theta
    matrix = BCOO((values, nonzeroes), shape=(n, n), unique_indices=True)
    return jnp.ones(n).T @ (matrix @ jnp.ones((n, p)))
appears to hit the slow version. My custom_vjp above also works for this case.
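For completeness, here is a sketch of routing the matrix-matrix case through the workaround (fixed_quadratic_mm is a name introduced here for illustration, with a final sum so that jax.grad applies):

def fixed_quadratic_mm(theta):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2)**2 / theta
    return (jnp.ones(n).T @ sparse_matrix_product(jnp.ones((n, p)), values, nonzeroes, (n, n))).sum()

jax.grad(fixed_quadratic_mm)(1.0)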
Reopening - I'm taking a look at the SpMM issue.
Meanwhile: if anyone needs this before a fix lands in the next version of JAX, here's my updated workaround for this and #14642. You can replace BCOO with bug_workaround_BCOO below, as long as you provide it with arguments for which indices_sorted and unique_indices are both true:
# Assumed imports (note: _bcoo_dot_general, _bcoo_rdot_general, and SparseInfo are
# internal JAX APIs whose locations may differ across versions; the Float/Array
# annotations are from jaxtyping):
from functools import partial
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp
from jax import custom_vjp, jit
from jax.experimental.sparse.bcoo import _bcoo_dot_general, _bcoo_rdot_general
from jax.experimental.sparse.util import SparseInfo
from jaxtyping import Array, Float

jax.config.update("jax_bcoo_cusparse_lowering", True)

class bug_workaround_BCOO():
    def __init__(self, args: Tuple[Array, Array], *, shape: Sequence[int], indices_sorted: bool, unique_indices: bool):
        assert indices_sorted
        assert unique_indices
        (self.values, self.nonzeroes) = args
        self.shape = shape

    def __matmul__(self, other: Float[Array, "N M"]) -> Float[Array, "N M"]:
        return sparse_matrix_product(other, self.values, self.nonzeroes, self.shape)

@partial(custom_vjp, nondiff_argnums=(3,))
@partial(jit, static_argnums=3)
def sparse_matrix_product(target: Float[Array, "M P"], values: Float[Array, "Z"], nonzeroes: Float[Array, "Z 2"], shape: Sequence[int]):
    # Contract the last sparse dimension against the matching dimension of the target.
    sparse_dim = len(shape) - 1
    dense_dim = target.ndim - 2 if target.ndim >= 2 else 0
    sparse_batch_dim = tuple(i for i in range(len(shape) - 2))
    dense_batch_dim = tuple(i for i in range(target.ndim - 2))
    dimension_numbers = (((sparse_dim,), (dense_dim,)), (sparse_batch_dim, dense_batch_dim))
    info = SparseInfo(shape, True, True)
    return _bcoo_dot_general(values, nonzeroes, target, dimension_numbers=dimension_numbers, lhs_spinfo=info)

def sparse_matrix_product_fwd(target: Float[Array, "M P"], values: Float[Array, "Z"], nonzeroes: Float[Array, "Z 2"], shape: Sequence[int]):
    out = sparse_matrix_product(target, values, nonzeroes, shape)
    carry = (target, values, nonzeroes)
    return out, carry

@partial(jit, static_argnums=0)
def sparse_matrix_product_rev(shape: Sequence[int], carry: Tuple[Float[Array, "M P"], Float[Array, "Z 2"], Float[Array, "Z"]], cotangents: Float[Array, "N P"]):
    (target, values, nonzeroes) = carry
    sparse_dim = len(shape) - 2
    dense_dim = cotangents.ndim - 1 if cotangents.ndim >= 2 else 0
    sparse_batch_dim = tuple(i for i in range(len(shape) - 2))
    dense_batch_dim = tuple(i for i in range(target.ndim - 2))
    dimension_numbers = (((dense_dim,), (sparse_dim,)), (dense_batch_dim, sparse_batch_dim))
    info = SparseInfo(shape, True, True)
    # Cotangent of the dense operand: multiply by the transposed sparse matrix.
    target_cotangents = _bcoo_rdot_general(cotangents.T, values, nonzeroes, dimension_numbers=dimension_numbers, rhs_spinfo=info).T
    # Cotangent of the values: read the dense cotangent off at the stored indices only.
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    sparse_cotangents = cotangents[..., x1, :]
    sparse_targets = target[..., x2, :]
    values_cotangents = (sparse_cotangents * sparse_targets).sum(axis=-1)
    return (target_cotangents, values_cotangents, None)

sparse_matrix_product.defvjp(sparse_matrix_product_fwd, sparse_matrix_product_rev)
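Hypothetical usage, mirroring the MWE above (the index pattern constructed there is sorted and unique, so the assertions pass):

matrix = bug_workaround_BCOO((values, nonzeroes), shape=(n, n), indices_sorted=True, unique_indices=True)
out = matrix @ jnp.ones((n, p))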
I think the issue reported here is fixed by #14630.
Description
Unless I am mistaken, differentiating the BCOO constructor appears to produce dense matrices. Here's an MWE:
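import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 1000
m = 777
p = 11
nonzeroes = jnp.stack((jnp.arange(0, m), jnp.arange(p, m + p)), axis=-1)

def quadratic(theta):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2)**2 / theta
    matrix = BCOO((values, nonzeroes), shape=(n, n), unique_indices=True)
    return jnp.ones(n).T @ matrix @ jnp.ones(n)

graph = jax.xla_computation(jax.grad(quadratic))(1.0)
with open("t.dot", "w") as f:
    f.write(graph.as_hlo_dot_graph())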
I've attached a computational graph: t.pdf. The problem is dot.28 (dark blue) followed by gather.57, which appears to assemble a 1000x1000 dense matrix and then grab only the parts of it at the sparse matrix's nonzero positions.

What jax/jaxlib version are you using?
0.4.4
Which accelerator(s) are you using?
N/A
Additional system info
N/A
NVIDIA GPU info
N/A