facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

"RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" when using autograd with knn #1533

Closed: TimoRST closed this issue 1 year ago

TimoRST commented 1 year ago

🐛 Bugs / Unexpected behaviors

I'm trying to combine your library pytorch3d with theseus. I'm using an autograd function as an objective in theseus, in which I need the K nearest neighbors of one point cloud with respect to another to perform a modified GICP. When theseus calls autograd on the objective, I get a RuntimeError in pytorch3d's knn.py at line 69, `idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)`:

    RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Previous calls work without problems; the error only occurs when autograd is called on the function. At that point, p1 and p2 are BatchedTensors, e.g.:

    p1 = {Tensor: (1, 1024, 3)} BatchedTensor(lvl=1, bdim=0, value=tensor(
        [[[[  4.0704,  -1.7715,  -1.6834],
           [ 44.4077,  61.2256,   0.9029],
           ...,
           [-13.8112, -16.6522,   0.1097],
           [-12.5197, -11.3855,   0.1289]]]], device='cuda:0'))
    p2 = {Tensor: (1, 1024, 3)} BatchedTensor(lvl=1, bdim=0, value=tensor(
        [[[[-1.4440e+01, -8.8393e+00, -3.6783e-02],
           [ 4.3913e+01,  6.1490...],
           ...,
           [ 3.0349e+01, -6.0585e+01, -6.5257e-01],
           [ 1.6265e+01, -5.2071e+00,  5.0570e-01]]]], device='cuda:0'))

I modified the class a little bit to work with PyTorch 2 (basically, I moved the ctx calls into the setup_context function and set generate_vmap_rule = True):

from typing import Any, Tuple

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from pytorch3d import _C


class _knn_points(Function):
    """
    Torch autograd Function wrapper for KNN C++/CUDA implementations.
    """

    generate_vmap_rule = True

    @staticmethod
    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
    def forward(
        p1,
        p2,
        lengths1,
        lengths2,
        K,
        version,
        norm: int = 2,
        return_sorted: bool = True,
    ):
        """
        K-Nearest neighbors on point clouds.

        Args:
            p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
                containing up to P1 points of dimension D.
            p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
                containing up to P2 points of dimension D.
            lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
                length of each pointcloud in p1. Or None to indicate that every cloud has
                length P1.
            lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
                length of each pointcloud in p2. Or None to indicate that every cloud has
                length P2.
            K: Integer giving the number of nearest neighbors to return.
            version: Which KNN implementation to use in the backend. If version=-1,
                the correct implementation is selected based on the shapes of the inputs.
            norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
            return_sorted: (bool) whether to return the nearest neighbors sorted in
                ascending order of distance.

        Returns:
            p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
                the nearest neighbors. This is padded with zeros both where a cloud in p2
                has fewer than K points and where a cloud in p1 has fewer than P1 points.

            p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
                K nearest neighbors from points in p1 to points in p2.
                Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
                neighbor to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
                in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
        """
        if not ((norm == 1) or (norm == 2)):
            raise ValueError("Support for 1 or 2 norm.")

        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)

        # sort KNN in ascending order if K > 1
        if K > 1 and return_sorted:
            if lengths2.min() < K:
                P1 = p1.shape[1]
                mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
                # mask has shape [N, K], true where dists irrelevant
                mask = mask[:, None].expand(-1, P1, -1)
                # mask has shape [N, P1, K], true where dists irrelevant
                dists[mask] = float("inf")
                dists, sort_idx = dists.sort(dim=2)
                dists[mask] = 0
            else:
                dists, sort_idx = dists.sort(dim=2)
            idx = idx.gather(2, sort_idx)

        # ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
        # ctx.mark_non_differentiable(idx)
        # ctx.norm = norm
        return dists, idx

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
        p1, p2, lengths1, lengths2, K, version, norm, return_sorted = inputs
        dists, idx = output
        ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
        ctx.mark_non_differentiable(idx)
        ctx.norm = norm

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idx):
        p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
        norm = ctx.norm
        # TODO(gkioxari) Change cast to floats once we add support for doubles.
        if not (grad_dists.dtype == torch.float32):
            grad_dists = grad_dists.float()
        if not (p1.dtype == torch.float32):
            p1 = p1.float()
        if not (p2.dtype == torch.float32):
            p2 = p2.float()
        grad_p1, grad_p2 = _C.knn_points_backward(
            p1, p2, lengths1, lengths2, idx, norm, grad_dists
        )
        return grad_p1, grad_p2, None, None, None, None, None, None
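
For reference, a minimal sketch of how I trigger the error (shapes and names assumed, this is not my real objective; in my setup theseus drives the autograd through vmap). With generate_vmap_rule = True the forward above is re-run under the vmap interpreter, so _C.knn_points_idx receives BatchedTensors:

    import torch
    from torch.func import vmap

    p1 = torch.rand(4, 1, 1024, 3, device="cuda")  # leading dim 4 is the vmap dim
    p2 = torch.rand(4, 1, 1024, 3, device="cuda")
    lengths = torch.full((1,), 1024, dtype=torch.int64, device="cuda")

    def objective(p1, p2):
        dists, idx = _knn_points.apply(p1, p2, lengths, lengths, 8, -1, 2, True)
        return dists.sum()

    # fails at the _C.knn_points_idx call:
    # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
    out = vmap(objective)(p1, p2)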

Torch version: 2.0.0 with CUDA 11.8

Any idea how to deal with this? I've got it running with a KNN implementation from pointnet2, but yours is much more efficient.

TimoRST commented 1 year ago

Just found out that you can call .storage() on a Tensor, but BatchedTensor doesn't seem to support that: NotImplementedError('Cannot access storage of BatchedTensorImpl')

I currently do not understand why this causes a problem with your KNN implementation while it works with standard Python code.
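
A tiny repro of that, independent of pytorch3d (plain PyTorch 2.0):

    import torch
    from torch.func import vmap

    def peek(x):
        x.storage()  # NotImplementedError: Cannot access storage of BatchedTensorImpl
        return x

    vmap(peek)(torch.rand(2, 3))

Presumably the C++ kernel needs the raw data pointer, which is exactly what a BatchedTensor doesn't expose.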

bottler commented 1 year ago

The main forward KNN function is here, and presumably that is what is failing. Are you saying it works if no gradients are required, but fails if they are?

I don't know how functorch/vmap/BatchedTensor etc. work, and our C++ and CUDA implementations have not been tested with them.

TimoRST commented 1 year ago

It also works when gradients are required, but it doesn't work with this jacrev/vmap/BatchedTensor machinery. I also have no experience with this kind of use, but it's required for Theseus to do the optimizer step.

I was hoping to find someone who knows what's going on, as I don't even understand the problem that's occurring here. Maybe I should ask about this error in the Theseus repo as well.

bottler commented 1 year ago

I understand the general problem but not the details or how to fix it properly. As a workaround, Theseus might have, or could be given, a special slow mode where it doesn't rely on functorch. I think it would be worth opening an issue for Theseus.

bottler commented 1 year ago

I think it would be possible to add a custom vmap staticmethod to _knn_points (see https://pytorch.org/docs/stable/notes/extending.func.html#defining-the-vmap-staticmethod) to make this work properly, and it wouldn't need any new CUDA code. It would need to combine the vmap dimension with the batch dimension, do the normal calculation, and then split the dimensions again.
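
Schematically something like this (an untested toy sketch following the pattern from that page, with a dummy Function standing in for _knn_points):

    import torch

    class BatchedSum(torch.autograd.Function):
        """Toy stand-in: pretend the kernel only understands (B, N) input."""

        @staticmethod
        def forward(x):
            return x.sum(dim=1)

        @staticmethod
        def setup_context(ctx, inputs, output):
            pass  # nothing to save in this toy

        @staticmethod
        def vmap(info, in_dims, x):
            # combine the vmap dimension V with the batch dimension B,
            V = x.shape[in_dims[0]]
            x = x.movedim(in_dims[0], 0).flatten(0, 1)  # (V * B, N)
            # do the normal calculation,
            out = BatchedSum.apply(x)                   # (V * B,)
            # and split the dimensions again
            return out.unflatten(0, (V, -1)), 0         # output batched at dim 0

    out = torch.vmap(BatchedSum.apply)(torch.rand(4, 2, 3))  # -> shape (4, 2)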

TimoRST commented 1 year ago

Thanks. I'm going to try it out. Maybe the automatic generation with generate_vmap_rule = True introduces some fault there.

TimoRST commented 1 year ago

Thanks. It really was that simple. I wrote a short vmap function (using einops for the reshapes) for vmapping over p1 and p2, and now it works:

    @staticmethod
    def vmap(info,
             in_dims,
             p1,
             p2,
             lengths1,
             lengths2,
             K,
             version,
             norm,
             return_sorted):
        # Fold the vmap dimension into the batch dimension. Note that
        # lengths1/lengths2 are passed through untouched, so they must
        # already match the folded (V * B) batch.
        V = None
        if in_dims[0] is not None:
            V = p1.shape[0]
            p1 = einops.rearrange(p1, 'V B N C -> (V B) N C')

        if in_dims[1] is not None:
            V = p2.shape[0]
            p2 = einops.rearrange(p2, 'V B N C -> (V B) N C')

        # With the stacked point clouds, we can just call the forward method.
        dists, idx = _knn_points.forward(
            p1, p2, lengths1, lengths2, K, version, norm, return_sorted
        )

        # Now split the vmap dimension back out of the outputs (this must
        # also happen when only p2 was vmapped, hence the shared V).
        if V is not None:
            dists = einops.rearrange(dists, '(V B) N K -> V B N K', V=V)
            idx = einops.rearrange(idx, '(V B) N K -> V B N K', V=V)

        return (dists, idx), (0, 0)
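
In case it helps anyone, a quick way to exercise the rule (a sketch with assumed shapes; since the rule only folds p1 and p2, the lengths tensors must already have V * B entries):

    import torch

    V, B, N, K = 4, 1, 1024, 8
    p1 = torch.rand(V, B, N, 3, device="cuda")
    p2 = torch.rand(V, B, N, 3, device="cuda")
    # lengths are captured unbatched, so size them for the folded batch
    lengths = torch.full((V * B,), N, dtype=torch.int64, device="cuda")

    knn = lambda a, b: _knn_points.apply(a, b, lengths, lengths, K, -1, 2, True)
    dists, idx = torch.vmap(knn)(p1, p2)  # dists: (V, B, N, K)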

xhsonny commented 6 months ago

@TimoRST hey, I am running into a similar issue where I am trying to run jacrev or jacfwd on an objective that uses the pytorch3d class _PointFaceDistance. I added a vmap method to the class, but jacrev and jacfwd both fail; they throw exceptions before hitting vmap.

What method did you use to compute the full Jacobian? Thanks!

TimoRST commented 5 months ago

Hey @xhsonny, sorry, I'm just back from vacation today.

I'm not quite sure what you mean, and I dropped the project where I used these methods shortly after this issue, so I'm not really confident about these parts anymore.

Can you clarify your question with a little bit of context?

Edit: Do you mean the Jacobian of the loss function? I didn't compute one myself. I used an AutoDiffCostFunction: https://github.com/facebookresearch/theseus/blob/e07569138cafa2d5da3e16ff2586d2495e77817c/theseus/core/cost_function.py#L203C7-L203C27
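
Roughly like this, from memory (the exact signature may differ, see the linked file; the variable names are made up):

    import theseus as th

    x = th.Vector(3, name="x")       # optimization variable
    y = th.Vector(3, name="y")       # auxiliary variable

    def err_fn(optim_vars, aux_vars):
        (x,), (y,) = optim_vars, aux_vars
        return x.tensor - y.tensor   # error of shape (batch, 3)

    # theseus differentiates err_fn itself, no hand-written Jacobian needed
    cost = th.AutoDiffCostFunction([x], err_fn, 3, aux_vars=[y])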