mit-han-lab / torchsparse

[MICRO'23, MLSys'22] TorchSparse: Efficient Training and Inference Framework for Sparse Convolution on GPUs.
https://torchsparse.mit.edu
MIT License

[BUG] Wrong SparseTensor.dense conversion #292

Closed: hontrn9122 closed this issue 7 months ago

hontrn9122 commented 7 months ago


Current Behavior

Given a dense torch tensor, for example:

import torch

test = torch.rand(3, 3, 2)

The resulting test tensor: [screenshot]

Then I convert this test tensor to a SparseTensor with the following code:

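from torchsparse import SparseTensor

sparse_data = test.to_sparse()

# torch COO indices have shape (3, N); transpose to (N, 3)
sparse_indices = sparse_data.indices().transpose(1, 0).contiguous()
# add a batch index column in front of the indices, giving shape (N, 4)
batch_col = torch.zeros(sparse_indices.size(0), 1, dtype=sparse_indices.dtype)
sparse_indices = torch.cat((batch_col, sparse_indices), dim=1).int()

sparse_feature = sparse_data.values().view(-1, 1)

sparse_data = SparseTensor(feats=sparse_feature.cuda(), coords=sparse_indices.cuda(), spatial_range=(3, 3, 2))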
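To make the index bookkeeping above concrete, here is a dependency-free sketch that mimics what `to_sparse().indices()` returns and what the transpose plus batch-column step produces. It uses nested Python lists instead of PyTorch, and the helper names `dense_to_coo` and `coo_to_coords` are hypothetical; it assumes that a coalesced COO tensor lists its nonzeros in row-major order.

```python
from itertools import product


def dense_to_coo(dense, shape):
    """Mimic the layout of torch.Tensor.to_sparse() for a nested-list tensor.

    Returns (indices, values): indices holds one list per dimension, i.e. shape
    [ndim, nnz], like sparse.indices(). Positions are visited in row-major
    order (assumed to match the order of a coalesced COO tensor).
    """
    indices = [[] for _ in shape]
    values = []
    for idx in product(*(range(n) for n in shape)):
        v = dense
        for i in idx:  # walk down the nested lists to the scalar at idx
            v = v[i]
        if v != 0:  # exact zeros are not stored
            for d, i in enumerate(idx):
                indices[d].append(i)
            values.append(v)
    return indices, values


def coo_to_coords(indices, batch_index=0):
    """Transpose [ndim, nnz] indices to nnz rows and prepend a batch column."""
    return [[batch_index, *row] for row in zip(*indices)]


dense = [[[0, 1], [2, 0]], [[0, 0], [3, 4]]]  # shape (2, 2, 2)
indices, values = dense_to_coo(dense, (2, 2, 2))
coords = coo_to_coords(indices)
print(indices)  # [[0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 0, 1]]
print(coords)   # [[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 1, 0], [0, 1, 1, 1]]
```

Each coordinate row ends up with four entries (batch, then the three original dimensions), which is why whatever shape is paired with these coords has to account for all four columns.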

Then when I convert it back to its dense counterpart, the result is different from the original one:

print(sparse_data.dense().cpu().squeeze())

[screenshot of the printed output]

Is my conversion code wrong, or is there a problem with the SparseTensor.dense() function?
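A dependency-free sketch of the reverse (sparse to dense) step can help check the round trip without a GPU. It assumes, based on the working helper posted later in this thread (which passes a shape that includes the batch dimension), that `spatial_range` needs one entry per coordinate column, batch included, i.e. `(1, 3, 3, 2)` rather than `(3, 3, 2)` for the example above. That is an unconfirmed assumption about TorchSparse 2.1.0, and `coords_to_dense` is a hypothetical helper that does not call TorchSparse.

```python
def coords_to_dense(coords, feats, spatial_range):
    """Scatter (coords, feats) into a nested-list tensor of shape spatial_range.

    Each coord row needs exactly one index per spatial_range entry, so if the
    rows carry a leading batch column, spatial_range must include the batch
    size as well.
    """
    for c in coords:
        if len(c) != len(spatial_range):
            raise ValueError(
                f"coord {c} has {len(c)} indices but spatial_range has {len(spatial_range)}"
            )

    def zeros(dims):
        # Build nested lists of zeros with the given shape.
        if len(dims) == 1:
            return [0] * dims[0]
        return [zeros(dims[1:]) for _ in range(dims[0])]

    out = zeros(tuple(spatial_range))
    for c, f in zip(coords, feats):
        node = out
        for i in c[:-1]:  # descend to the innermost list
            node = node[i]
        node[c[-1]] = f
    return out


dense = [[[0, 1], [2, 0]], [[0, 0], [3, 4]]]
coords = [[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 1, 0], [0, 1, 1, 1]]
feats = [1, 2, 3, 4]
restored = coords_to_dense(coords, feats, (1, 2, 2, 2))  # batch size 1 included
print(restored[0] == dense)  # True
```

Under this sketch, passing the three-entry range `(2, 2, 2)` together with four-column coords is rejected instead of silently producing a scrambled tensor.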

Expected Behavior

No response

Environment

- GCC:
- NVCC:
- PyTorch:2.1.2
- PyTorch CUDA: 11.8
- TorchSparse: 2.1.0

Anything else?

No response

ZzTodd22 commented 7 months ago

Hello! Have you figured out how to convert a dense tensor to a sparse tensor? For instance, I'm looking to convert a dense tensor with the shape [B, T, C, H, W] to a sparse tensor. Any insights or solutions would be greatly appreciated. Thank you!

hontrn9122 commented 7 months ago


I use torch.Tensor.to_sparse() to convert the dense tensor to a COO sparse tensor, then I use the indices, values, and size of that sparse tensor to create a TorchSparse tensor: SparseTensor(feats=values, coords=indices, spatial_range=size). Remember to transform the indices as specified in the TorchSparse docs (PyTorch returns them with shape [ndim, nnz], so they need to be transposed to one row per point), and add a batch dimension if your original tensor does not have one.
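For the multi-channel case in the question, the helper posted later in this thread relies on the hybrid conversion `x.to_sparse(x.ndim - 1)`, where the last (channel) dimension stays dense and each kept site gets its whole channel vector as features. Below is a dependency-free sketch of that idea on nested lists; `channel_last_to_sparse` is a hypothetical name, and the rule that a site is kept when any of its channel values is nonzero reflects our understanding of PyTorch's behavior rather than something stated in this thread. A [B, T, C, H, W] tensor would first need its channel axis moved last, e.g. with `x.permute(0, 1, 3, 4, 2)` in PyTorch.

```python
from itertools import product


def channel_last_to_sparse(dense, shape):
    """Mimic x.to_sparse(x.ndim - 1) for a channel-last nested-list tensor.

    shape is the full shape, e.g. (B, T, H, W, C). A site is kept if any of its
    C channel values is nonzero; its coords are the leading ndim - 1 indices
    (batch first) and its feature is the whole channel vector.
    """
    coords, feats = [], []
    for idx in product(*(range(n) for n in shape[:-1])):
        vec = dense
        for i in idx:  # descend to the length-C channel vector at this site
            vec = vec[i]
        if any(v != 0 for v in vec):
            coords.append(list(idx))
            feats.append(list(vec))
    # The returned range keeps the batch dim, matching shape[:-1].
    return coords, feats, tuple(shape[:-1])


# B=1, H=2, W=2, C=2: only two sites have a nonzero channel.
x = [[[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]
coords, feats, spatial_range = channel_last_to_sparse(x, (1, 2, 2, 2))
print(coords)         # [[0, 0, 1], [0, 1, 1]]
print(feats)          # [[5, 0], [0, 7]]
print(spatial_range)  # (1, 2, 2)
```

Keeping the channel dimension dense means each point carries a C-dimensional feature vector, instead of every nonzero scalar becoming its own point as in the original report.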

ZzTodd22 commented 7 months ago

Thank you very much, I wrote a function and it worked!


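import torch
from torchsparse import SparseTensor


def from_dense(x: torch.Tensor) -> SparseTensor:
    """Create a SparseTensor from a channel-last dense tensor via to_sparse.

    x must be a channel-last BTHWC tensor.
    """
    # Keep the channel dim dense: each nonzero site gets its full C-vector as features.
    sparse_data = x.to_sparse(x.ndim - 1)
    # Shape of all non-channel dims, batch included, e.g. (B, T, H, W).
    spatial_shape = sparse_data.shape[:-1]
    # torch COO indices are [ndim - 1, nnz]; transpose to one integer coord row per site.
    sparse_indices = sparse_data.indices().transpose(1, 0).contiguous().int()
    sparse_feature = sparse_data.values()
    return SparseTensor(feats=sparse_feature.cuda(), coords=sparse_indices.cuda(), spatial_range=spatial_shape)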