Open cherrywoods opened 4 days ago
I came up with a hotfix but I don't think what I implemented is the generally desirable behaviour. My fix is to insert these lines:
while i1 > 0 and new_sizes[i1 - 1] == 1:
    i1 -= 1
while i2 > 0 and new_sizes[i2 - 1] == 1:
    i2 -= 1
This moves all dimensions of size one to the next dimension kind (batch -> sparse, sparse -> dense). This works in my case, but I figure there could be other cases where this might be precisely the wrong thing to do?
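To make the effect of the two loops concrete, here is a small standalone sketch in plain Python. It is not the actual JAX code: the helper names `shift_unit_dims` and `split_by_kind` are hypothetical, and it assumes that `i1` and `i2` are the indices at which `new_sizes` gets split into batch, sparse, and dense parts.

```python
def shift_unit_dims(new_sizes, i1, i2):
    """Apply the two hotfix loops and return the adjusted (i1, i2).

    Assumption: new_sizes[:i1] are batch dims, new_sizes[i1:i2] are
    sparse dims, and new_sizes[i2:] are dense dims.
    """
    # Trailing size-1 batch dims are handed over to the sparse part.
    while i1 > 0 and new_sizes[i1 - 1] == 1:
        i1 -= 1
    # Trailing size-1 sparse dims are handed over to the dense part.
    while i2 > 0 and new_sizes[i2 - 1] == 1:
        i2 -= 1
    return i1, i2


def split_by_kind(new_sizes, i1, i2):
    """Split new_sizes into (batch, sparse, dense) tuples at i1 and i2."""
    return tuple(new_sizes[:i1]), tuple(new_sizes[i1:i2]), tuple(new_sizes[i2:])


new_sizes = (2, 1, 3, 1, 4)
before = split_by_kind(new_sizes, 2, 4)
after = split_by_kind(new_sizes, *shift_unit_dims(new_sizes, 2, 4))
print("before:", before)  # before: ((2, 1), (3, 1), (4,))
print("after: ", after)   # after:  ((2,), (1, 3), (1, 4))
```

In this example each size-1 dimension ends up one kind further to the right, which is exactly the behaviour I am unsure about in general: a caller who wants a size-1 dimension to stay a batch (or sparse) dimension has no way to ask for that.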
Description
Reshape for sparse BCOO arrays fails if the target shape contains dimensions of size 1 and there is at least one dense dimension.

Stack trace:
System info (python version, jaxlib version, accelerator, etc.)
Note: