Open · muchanem opened this issue 6 days ago

Description

When jnp.nonzero (or jnp.argwhere, or sparse.BCOO.fromdense) is run on a sparse array that is sharded, the operation is extremely slow; the problem does not occur without sharding. This is a little odd to me, since there should be little (or no) communication between devices for this op: it can be trivially split across the devices. I'm actually unsure whether the op ever completes; I've let it run for 10+ minutes with no result.

System info (python version, jaxlib version, accelerator, etc.)

Follow-up:

Seems related to the size calculation, which would explain a bunch of communication overhead making the op slow. JITing it with a specified size speeds up the computation significantly, though it is still much, much slower than NumPy on CPU. After a warmup loop, it's pretty quick (tested on a TPU):
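A minimal sketch of that workaround and timing setup (the array length `N`, the bound `MAX_NNZ`, and the sparsity are illustrative choices, not values from the report). Passing `size` to `jnp.nonzero` gives the output a static shape, so the call can be jitted without a data-dependent size calculation:

```python
import time
import jax
import jax.numpy as jnp

N = 1_000_000     # illustrative array length
MAX_NNZ = 10_000  # assumed upper bound on the number of nonzeros

@jax.jit
def nonzero_fixed_size(x):
    # With a static `size`, the output shape is known at trace time,
    # so no run-time size calculation is needed.
    return jnp.nonzero(x, size=MAX_NNZ, fill_value=-1)[0]

key = jax.random.PRNGKey(0)
x = jnp.where(jax.random.uniform(key, (N,)) < 0.01, 1.0, 0.0)  # ~1% nonzero

# Warmup: the first call triggers compilation.
nonzero_fixed_size(x).block_until_ready()

t0 = time.perf_counter()
idx = nonzero_fixed_size(x)
idx.block_until_ready()
print(f"elapsed: {time.perf_counter() - t0:.6f}s")
```

Note that the result is padded out to `MAX_NNZ` entries with `fill_value`, so downstream code has to mask out the -1 padding.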
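For reference, a minimal sketch of the kind of sharded setup the report describes, assuming a multi-device host (e.g. 8 TPU cores); the mesh axis name, array size, and sparsity are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))

# A sparse 1-D array: ~1% of entries are nonzero.
x = jnp.zeros(1_000_000).at[::100].set(1.0)
x = jax.device_put(x, NamedSharding(mesh, P("i")))  # shard along the one axis

# The slow call: without `size`, the result shape depends on the data,
# which appears to trigger the pathological behavior on sharded inputs.
idx = jnp.nonzero(x)[0]
idx.block_until_ready()
```

Since each device could in principle find the nonzeros in its own shard independently, this is the sense in which the op "can be trivially split across the devices".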