jax.experimental.sparse.bcoo_reshape implicitly flattens the array being reshaped (here) and then unflattens it into the new shape. For large sparse arrays, the position of an entry in the flattened array can exceed int32_max. When that happens, the flattened index overflows int32 before it is converted back into a multi-index for the new shape, so the entry silently ends up in the wrong place! This is unexpected when int32 is sufficient for both the source shape and the target shape.
Using int64 indices solves the problem in my use case. Ideally, though, bcoo_reshape should not need int64 when neither the source nor the target shape does.
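To make the numbers concrete, here is a NumPy-only sketch (no JAX involved) using the shape from the reproduction below. The row-major position of the bottom-right entry of a (4000000, 1000) array does not fit in int32. The last step assumes the negative, wrapped-around value is shifted back into range by adding the total size; that is a guess about the internals, not something checked against the JAX source, but it reproduces the reported [[3705032, 703]] exactly.

```python
import numpy as np

# Shape and bottom-right entry from the reproduction.
shape = (4_000_000, 1_000)
row, col = shape[0] - 1, shape[1] - 1

# Row-major flattened position, computed in 64-bit arithmetic.
flat = int(np.ravel_multi_index((row, col), shape))
int32_max = int(np.iinfo(np.int32).max)
print(flat, flat > int32_max)  # 3999999999 True

# Storing that position in int32 wraps it around to a negative value.
wrapped = int(np.array(flat, dtype=np.int64).astype(np.int32))
print(wrapped)  # -294967297

# Assumption: the negative index is shifted back by the total size before
# being unflattened. The result is a different entry than (row, col).
total = shape[0] * shape[1]
moved = tuple(int(i) for i in np.unravel_index(wrapped + total, shape))
print(moved)  # (3705032, 703)
```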
Code to reproduce:
```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp

def test_if_trivial_reshape_works(n, d, index_dtype):
    # Place an entry in the bottom right-hand corner.
    entries = jnp.array([1.1])
    indices = jnp.array([[n - 1, d - 1]]).astype(index_dtype)
    m = jsp.BCOO((entries, indices), shape=(n, d))
    # Reshape to the current shape. This should be a no-op.
    m_reshaped = jsp.bcoo_reshape(m, new_sizes=m.shape, dimensions=[])
    if (m_reshaped.indices == m.indices).all():
        print('All is well')
    else:
        print(f'Failure!\n{m.indices=}\n{m_reshaped.indices=}')

print(jax.__version__)  # 0.4.9
test_if_trivial_reshape_works(100, 300, 'int32')  # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int32')  # Failure! [[3999999, 999]] vs [[3705032, 703]]

# Required for using int64 indices.
jax.config.update('jax_enable_x64', True)
test_if_trivial_reshape_works(100, 300, 'int64')  # All is well
test_if_trivial_reshape_works(100, 300, 'int32')  # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int64')  # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int32')  # Failure! [[3999999, 999]] vs [[3705032, 703]]
```
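Until the flattening is done in a wider type, one possible workaround is to do the index arithmetic on the host in int64 and rebuild the BCOO array. This is a sketch, not part of the JAX API: the helper name reshape_bcoo_host is hypothetical, it only handles arrays with no batch or dense dimensions, and it cannot be used under jit because it pulls the indices back into NumPy.

```python
import math

import numpy as np
import jax.numpy as jnp
import jax.experimental.sparse as jsp


def reshape_bcoo_host(mat, new_shape):
    """Hypothetical helper: reshape a purely sparse BCOO array by doing the
    flatten/unflatten index arithmetic on the host in int64."""
    new_shape = tuple(int(s) for s in new_shape)
    if mat.n_batch or mat.n_dense:
        raise NotImplementedError("sketch only handles n_batch == n_dense == 0")
    if math.prod(mat.shape) != math.prod(new_shape):
        raise ValueError(f"cannot reshape {mat.shape} into {new_shape}")
    idx = np.asarray(mat.indices).astype(np.int64)        # (nse, ndim)
    # int64 flat positions; raises on out-of-range (e.g. padding) indices.
    flat = np.ravel_multi_index(tuple(idx.T), mat.shape)
    new_idx = np.stack(np.unravel_index(flat, new_shape), axis=-1)
    # The final indices only need int32 if every target dimension fits;
    # otherwise jax_enable_x64 must be on to keep them as int64.
    if max(new_shape, default=0) <= np.iinfo(np.int32).max:
        new_idx = new_idx.astype(np.int32)
    return jsp.BCOO((mat.data, jnp.asarray(new_idx)), shape=new_shape)
```

For example, calling reshape_bcoo_host(m, (1000, 4000000)) on the (4000000, 1000) matrix from the reproduction should move the entry to [[999, 3999999]] while keeping int32 indices.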
What jax/jaxlib version are you using?
jax v0.4.9
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10 on Linux