willow-ahrens / Finch.jl

Sparse tensors in Julia and more! Datastructure-driven array programming language.
http://willowahrens.io/Finch.jl/
MIT License

Question regarding 3d swizzle array #387

Closed: mtsokol closed this issue 7 months ago

mtsokol commented 7 months ago

Hi @willow-ahrens,

I have a question about permuting a 3D swizzle array. Here is a short snippet where permuting a dense array gives a different shape than permuting the equivalent Finch array, although the resulting shapes should be equal (other permutations of the dimensions worked as expected):

using Finch

A = zeros(2, 4, 3)
A[1,:,:] = [0.0 0.0 4.4; 1.1 0.0 0.0; 0.0 0.0 0.0; 3.3 0.0 0.0]
A[2,:,:] = [1.0 0.0 0.0; 0.0 0.0 0.0; 0.0 1.0 0.0; 3.3 0.0 0.0]

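permutation = (3, 1, 2)

# reference result from Base.permutedims on the dense array: (3, 2, 4)
new_shape_1 = size(permutedims(A, permutation))

t = Tensor(Dense(SparseList(SparseList(Element(0.0)))), A)
st = swizzle(t, permutation...)
# materialize the swizzle by copying it into a fresh tensor of the same format
new_shape_2 = size(Tensor(Dense(SparseList(SparseList(Element(0.0)))), st))

println(new_shape_1, " ", new_shape_2)  # these should be equal, but they differ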

Is my reasoning correct?
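As a Finch-independent cross-check of the expected answer, here is a minimal sketch using only Base Julia: permutedims (an eager copy) and PermutedDimsArray (a lazy wrapper, loosely analogous to a swizzle) should both report size (3, 2, 4) for this input, because dimension k of the result is dimension perm[k] of A. The element rule B[i, j, k] == A[j, k, i] holds for the permutation (3, 1, 2) specifically. Treating Base's permutedims as the reference semantics for Finch's swizzle is an assumption here, not something stated in the thread.

```julia
# Base Julia only: what shape and element order should a (3, 1, 2) permutation give?
A = zeros(2, 4, 3)
A[1, :, :] = [0.0 0.0 4.4; 1.1 0.0 0.0; 0.0 0.0 0.0; 3.3 0.0 0.0]
A[2, :, :] = [1.0 0.0 0.0; 0.0 0.0 0.0; 0.0 1.0 0.0; 3.3 0.0 0.0]
perm = (3, 1, 2)

B = permutedims(A, perm)          # eager permuted copy
L = PermutedDimsArray(A, perm)    # lazy permuted view, no data copied

# Shape rule: result dimension k is input dimension perm[k], giving (3, 2, 4).
expected = ntuple(k -> size(A, perm[k]), 3)

# Element rule for this particular permutation: B[i, j, k] == A[j, k, i].
elements_ok = all(B[i, j, k] == A[j, k, i] for i in 1:3, j in 1:2, k in 1:4)

println(size(B), " ", size(L), " ", expected, " ", elements_ok)
```

If a materialized Finch swizzle disagrees with size(B) here, the discrepancy is on the Finch side rather than in the reasoning.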

willow-ahrens commented 7 months ago

Ah, maybe somewhere in Finch we're missing an invperm.
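If the cause is a missing (or extra) invperm, the symptom would be predictable. The minimal Base Julia sketch below uses a hypothetical helper, permuted_size, that applies the permutedims shape rule: using invperm(perm) where perm was expected turns (3, 2, 4) into (4, 3, 2). It would also be consistent with the report's note that other permutations worked, since four of the six permutations of three dimensions are their own inverse, which makes such a mix-up invisible for them. This illustrates the hypothesis only; it is not Finch's code.

```julia
# Base Julia only; a sketch of the "missing invperm" hypothesis, not Finch internals.
sz = (2, 4, 3)                    # size of A in the report

# Hypothetical helper: result dimension k is input dimension p[k] (permutedims rule).
permuted_size(sz, p) = ntuple(k -> sz[p[k]], length(p))

perm = [3, 1, 2]
iperm = invperm(perm)             # [2, 3, 1]

good = permuted_size(sz, perm)    # (3, 2, 4), matches permutedims
bad = permuted_size(sz, iperm)    # (4, 3, 2), what a perm/invperm swap would give

# Only permutations that differ from their own inverse expose the mix-up.
all_perms = [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
exposed = [p for p in all_perms if p != invperm(p)]
println(exposed)                  # [[2, 3, 1], [3, 1, 2]]
```

Note that (3, 1, 2), the permutation from the report, is one of the two exposed 3-cycles.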

willow-ahrens commented 7 months ago

So, one related issue is that copyto! in Finch doesn't return a swizzle; it actually pops the swizzle off the tensor: https://github.com/willow-ahrens/Finch.jl/blob/0a536c7f43ef76ffd55ca0a927e654466e1efda8/src/transforms/wrapperize.jl#L181-L188

But I don't think that is the problem we're seeing here. It may simply be that the rewrite rules for Swizzle are broken.