mit-han-lab / torchsparse

[MICRO'23, MLSys'22] TorchSparse: Efficient Training and Inference Framework for Sparse Convolution on GPUs.
https://torchsparse.mit.edu
MIT License

[BUG] Calling backward with fp16 inputs results in RuntimeError: expected scalar type Float but found Half #331

Open luisa-waabi opened 1 month ago

luisa-waabi commented 1 month ago

Is there an existing issue for this?

Current Behavior

When running a sparse convolution's forward pass in fp16, everything seems to work fine. The backward pass, however, throws a RuntimeError.

Minimal reproducible script, adapted from the provided example:

from datetime import datetime
import numpy as np
import torch
import torch.cuda
import torch.nn as nn
import torch.optim
import torchsparse.nn as spnn
from torchsparse import SparseTensor
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize

def generate_random_point_cloud(size=100000, voxel_size=0.2):
    pc = np.random.randn(size, 4)
    pc[:, :3] = pc[:, :3] * 10
    labels = np.random.choice(10, size)
    coords, feats = pc[:, :3], pc
    coords -= np.min(coords, axis=0, keepdims=True)
    coords, indices = sparse_quantize(coords, voxel_size, return_index=True)
    coords = torch.tensor(coords, dtype=torch.int)
    feats = torch.tensor(feats[indices], dtype=torch.float)
    labels = torch.tensor(labels[indices], dtype=torch.long)
    input = SparseTensor(coords=coords, feats=feats)
    label = SparseTensor(coords=coords, feats=labels)
    feed_dict = {"input": input, "label": label}
    return feed_dict

def generate_batched_random_point_clouds(size=100000, voxel_size=0.2, batch_size=2):
    batch = []
    for _ in range(batch_size):
        batch.append(generate_random_point_cloud(size, voxel_size))
    return sparse_collate_fn(batch)

def dummy_train_3x3(device):
    model = nn.Sequential(
        spnn.Conv3d(4, 32, kernel_size=3, stride=1),
        spnn.Conv3d(32, 10, kernel_size=3, stride=1, transposed=True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)

    print("Starting dummy_train_3x3...")
    time = datetime.now()
    for i in range(10):
        feed_dict = generate_batched_random_point_clouds()
        inputs = feed_dict["input"].to(device)
        inputs = inputs.half() # convert to fp16
        targets = feed_dict["label"].F.to(device).long()
        outputs = model(inputs)
        optimizer.zero_grad()
        loss = criterion(outputs.F, targets)
        loss.backward() # throws error
        optimizer.step()
        print('[step %d] loss = %f.'%(i, loss.item()))

    time = datetime.now() - time
    print("Finished dummy_train_3x3 in ", time)

dummy_train_3x3("cuda")

Stack trace

  File ".../torchsparse_py/torchsparse/nn/functional/conv/func/implicit_gemm.py", line 224, in backward
    grad_input, grad_weight = backward(
  File ".../torchsparse_py/torchsparse/nn/functional/conv/func/implicit_gemm.py", line 152, in backward
    torch.ops.torchsparse_ops.conv_backward_wgrad_implicit_gemm_sorted_cuda(
  File ".../pip_torch/site-packages/torch/_ops.py", line 854, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: expected scalar type Float but found Half

Expected Behavior

I expect no errors to be thrown on the backward pass. I would also love to have autocast compatibility, for which I think the ImplicitGEMMConvolutionFuntion functions need to be annotated with the @custom_fwd and @custom_bwd decorators.
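A minimal sketch (not torchsparse's actual code) of what those annotations could look like on a custom autograd Function. MyConvFunction, feats, and weight are placeholder names, and a dense matmul stands in for the implicit GEMM kernels:

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyConvFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd  # or custom_fwd(cast_inputs=torch.float32) to force fp32 inside the op
    def forward(ctx, feats, weight):
        ctx.save_for_backward(feats, weight)
        return feats @ weight  # placeholder for the implicit GEMM forward kernel

    @staticmethod
    @custom_bwd  # runs backward under the same autocast state as the forward
    def backward(ctx, grad_output):
        feats, weight = ctx.saved_tensors
        grad_feats = grad_output @ weight.t()  # placeholder dgrad
        grad_weight = feats.t() @ grad_output  # placeholder wgrad
        return grad_feats, grad_weight

# usage: out = MyConvFunction.apply(feats, weight)

Whether custom_fwd should also pass cast_inputs=torch.float32 (forcing fp32 inside the op) depends on whether the backward CUDA kernels actually support fp16, so this only illustrates the decorator mechanics; the real fix in implicit_gemm.py may differ.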

Environment

TorchSparse: 2.1.0

Anything else?

No response