google / jax


Indices of sparse arrays get corrupted by `jax.experimental.sparse.bcoo_reshape` if their _flattened_ values exceed int32_max #15896

Open geraschenko opened 1 year ago

geraschenko commented 1 year ago

Description

jax.experimental.sparse.bcoo_reshape implicitly flattens the array being reshaped and then unflattens it into the new shape. For large sparse arrays, the flattened index of an entry can exceed int32_max. When that happens, the flattened index overflows int32 before it is converted back into a multi-index for the new shape, so entries silently end up in the wrong place! This is unexpected when int32 is sufficient to index both the source shape and the target shape.
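To make the arithmetic concrete, here is a minimal pure-Python model of the flatten/unflatten round trip for the failing case in the reproduction (shape (4000000, 1000), entry at [3999999, 999]). It is only a sketch of the integer arithmetic, not JAX's implementation, and the helper names `wrap_int32`, `ravel` and `unravel` are hypothetical. It suggests that the corrupted indices reported in this issue, [[3705032, 703]], are what you get if the flattened index wraps around in signed 32-bit arithmetic, whereas saturating at int32_max would give a different result.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(x):
    """Reinterpret an integer as a signed 32-bit value (two's complement wraparound)."""
    return (x + 2**31) % 2**32 - 2**31


def ravel(row, col, shape):
    """Row-major flat index of (row, col), computed with unbounded Python ints."""
    return row * shape[1] + col


def unravel(flat, shape):
    """Row-major multi-index of a flat index, using floor division and modulo."""
    n, d = shape
    return (flat // d) % n, flat % d


shape = (4_000_000, 1000)
flat = ravel(3_999_999, 999, shape)

print(flat, flat > INT32_MAX)                  # the true flat index does not fit in int32
print(unravel(flat, shape))                    # exact arithmetic recovers (3999999, 999)
print(unravel(wrap_int32(flat), shape))        # wraparound model: (3705032, 703)
print(unravel(min(flat, INT32_MAX), shape))    # saturation model: (2147483, 647)
```

The wraparound model reproduces the indices in the failure message exactly, which is consistent with the intermediate flat index being held in an int32.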

Using int64 indices solves the problem in my use case. Ideally, though, a reshape should work correctly whenever int32 is sufficient to index both the source shape and the target shape.
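The desired behavior can be sketched in plain Python: only the intermediate flat index needs more than 32 bits, while every input and output index still fits in int32. The function `reshape_index` below is a hypothetical illustration, not part of JAX. It relies on Python's unbounded integers as a stand-in for a wider intermediate type. In JAX itself an int64 intermediate would require `jax_enable_x64`, so an actual fix may need a different approach.

```python
INT32_MAX = 2**31 - 1


def reshape_index(index, old_shape, new_shape):
    """Map a multi-index in old_shape to the matching multi-index in new_shape.

    Both shapes must have the same total number of elements. The intermediate
    flat index is an unbounded Python int, standing in for a wide integer type.
    """
    # Ravel in row-major order.
    flat = 0
    for i, size in zip(index, old_shape):
        flat = flat * size + i

    # Unravel into new_shape, peeling off the last dimension first.
    out = []
    for size in reversed(new_shape):
        flat, remainder = divmod(flat, size)
        out.append(remainder)
    return tuple(reversed(out))


old_shape = (4_000_000, 1000)
corner = (3_999_999, 999)

# A no-op reshape leaves the index unchanged.
print(reshape_index(corner, old_shape, old_shape))

# Reshaping to (1000, 4000000) keeps the entry in the last position.
print(reshape_index(corner, old_shape, (1000, 4_000_000)))
```

In both cases every component of the result is below int32_max, even though the intermediate flat index (3999999999) is not.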

Code to reproduce:

import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp

def test_if_trivial_reshape_works(n, d, index_dtype):
  # Place an entry in the bottom right-hand corner.
  entries = jnp.array([1.1])
  indices = jnp.array([[n - 1, d - 1]]).astype(index_dtype)
  m = jsp.BCOO((entries, indices), shape=(n, d))

  # Reshape to the current shape. This should be a no-op.
  m_reshaped = jsp.bcoo_reshape(m, new_sizes=m.shape, dimensions=[])

  if (m_reshaped.indices == m.indices).all():
    print('All is well')
  else:
    print(f'Failure!\n{m.indices=}\n{m_reshaped.indices=}')

print(jax.__version__)  # 0.4.9

test_if_trivial_reshape_works(100, 300, 'int32')       # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int32')  # Failure! [[3999999, 999]] vs [[3705032, 703]]

# Required for using int64 indices.
jax.config.update('jax_enable_x64', True)

test_if_trivial_reshape_works(100, 300, 'int64')       # All is well
test_if_trivial_reshape_works(100, 300, 'int32')       # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int64')  # All is well
test_if_trivial_reshape_works(4000000, 1000, 'int32')  # Failure! [[3999999, 999]] vs [[3705032, 703]]

What jax/jaxlib version are you using?

jax v0.4.9

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10 on Linux

NVIDIA GPU info

No response

geraschenko commented 1 year ago

@jakevdp