AlexanderMath opened this issue 1 year ago
Found the same issue with jax.cumprod when trying to use the log-trick as a temporary solution.
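For context, the log-trick here just means computing cumprod through cumsum in log space (only valid for strictly positive inputs), which is why it runs into the same cumsum problem; a minimal sketch:

import jax.numpy as jnp

def cumprod_via_logs(x):
    # cumprod(x) == exp(cumsum(log(x))) for strictly positive x,
    # so this goes through the same jnp.cumsum code path.
    return jnp.exp(jnp.cumsum(jnp.log(x)))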
Here's a hacky temporary workaround. It uses matrix multiplication to compute jnp.cumsum within chunks of 2**7 elements and then adds the correct offset to each chunk. Use with caution: >90% of the time is spent adding the offsets.
import jax
import jax.numpy as jnp

def matmul_cumsum_jax(arr):
    # Cumulative sum of a 1D array via a lower-triangular matmul.
    return jnp.tril(jnp.ones((len(arr), len(arr)))) @ arr

def cumsum_jax(arr):
    chunk_size = 2**7
    original_shape = arr.shape
    # Pad so the length is a multiple of chunk_size.
    padding = chunk_size - (len(arr) % chunk_size) if len(arr) % chunk_size != 0 else 0
    arr = jnp.pad(arr, (0, padding))
    num_chunks = -(-len(arr) // chunk_size)
    chunks = arr.reshape(num_chunks, chunk_size)
    # Cumulative sum within each chunk.
    chunks = jax.vmap(matmul_cumsum_jax)(chunks)
    # Offset of each chunk = sum of all previous chunks (Python loop over chunks).
    offset = 0
    offsets = [offset]
    for chunk in chunks:
        offset += chunk[-1]
        offsets.append(offset)
    chunks = jax.vmap(jax.lax.add, in_axes=(0, 0))(chunks, jnp.array(offsets[:-1]))
    return jnp.concatenate(chunks).reshape(-1)[:original_shape[0]]
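Since >90% of the time goes into the offset loop above, one possible variation (an untested sketch of my own, reusing matmul_cumsum_jax from above) is to compute the offsets with a second, much smaller matmul cumsum over the per-chunk totals instead of the Python loop:

def cumsum_jax_v2(arr, chunk_size=2**7):
    # Hypothetical variant: offsets come from an exclusive cumsum over the
    # per-chunk totals (again done via matmul, avoiding jnp.cumsum entirely).
    n = arr.shape[0]
    padded = jnp.pad(arr, (0, (-n) % chunk_size))
    chunks = padded.reshape(-1, chunk_size)
    chunks = jax.vmap(matmul_cumsum_jax)(chunks)
    totals = chunks[:, -1]
    offsets = matmul_cumsum_jax(totals) - totals  # exclusive cumsum of the totals
    return (chunks + offsets[:, None]).reshape(-1)[:n]

No idea whether this is actually faster on the IPU backend; it just removes the per-chunk Python loop from the traced program.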
import numpy as np

arange = np.arange(2**14)
arange = np.concatenate((np.zeros(1), np.diff(arange))).astype(np.int32)
true_indxs = np.cumsum(arange)
us_indxs = np.asarray(jax.jit(cumsum_jax, backend="ipu")(arange)).astype(np.int32)
print(true_indxs[::127])
print(us_indxs[::127])
print(np.max(np.abs(true_indxs - us_indxs)))
print(np.all(true_indxs == us_indxs))
Description
Reproducer
Output
Note:
Meta comment: the reproducer took 2 hours to make because jnp.cumsum was used inside ~400 lines of code, and I wrongly assumed jnp.cumsum was unlikely to cause a segmentation fault compared to: tessellate-ipu, C code, usage of uint in C code, Poplar simulation of uint64 in C code, passing from Python to C code, index computations, ... . Would it be a lot of work to add automated testing for these basic (np, jnp) functions?
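A minimal version of such a test could look roughly like the sketch below; the op list, input, and tolerance are placeholders of mine.

import numpy as np
import jax
import jax.numpy as jnp
import pytest

# Hypothetical sketch: check a few basic jnp ops on the IPU backend
# against their NumPy counterparts.
BASIC_OPS = [
    (jnp.cumsum, np.cumsum),
    (jnp.cumprod, np.cumprod),
    (jnp.sum, np.sum),
]

@pytest.mark.parametrize("jnp_op,np_op", BASIC_OPS)
def test_basic_op_matches_numpy(jnp_op, np_op):
    # Values near 1.0 keep cumprod well-behaved in float32.
    x = np.linspace(0.5, 1.5, 2**10).astype(np.float32)
    expected = np_op(x)
    got = np.asarray(jax.jit(jnp_op, backend="ipu")(x))
    np.testing.assert_allclose(got, expected, rtol=1e-4)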
What jax/jaxlib version are you using?
0.3.16
Which accelerator(s) are you using?
IPU MK2
Additional System Info
No response