facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Any plan to support `torch.func`? #1796

Open xhsonny opened 1 month ago

xhsonny commented 1 month ago

I am trying to compute the full Jacobian using jacrev or jacfwd from torch.func. Part of the loss function uses _PointFaceDistance. Out of the box, pytorch3d does not support torch.func. The closest references I have found so far are https://github.com/facebookresearch/pytorch3d/issues/1636 and https://github.com/facebookresearch/pytorch3d/issues/1533.

The problems I am having are

The only method that works is torch.autograd.functional.jacobian(vectorize=False), which is very slow. When I turn on vectorize=True, it runs into the same issues as above.
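For reference, the two code paths being compared can be sketched on a pure-PyTorch function. The function f below is a hypothetical stand-in that uses only built-in ops, so both paths succeed on it; it does not reproduce the pytorch3d failure, it only shows which calls are involved.

```python
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev


def f(x):
    # Hypothetical stand-in for the loss: squared norm of each row, shape (P,).
    return (x**2).sum(dim=-1)


x = torch.randn(4, 3)

# Row-by-row Jacobian: one backward pass per output element (the slow path).
jac_loop = jacobian(f, x, vectorize=False)

# Vectorized Jacobian: jacrev runs vmap over the vector-Jacobian product.
jac_vmap = jacrev(f)(x)

print(jac_loop.shape, jac_vmap.shape)
```

Both results have shape (4, 4, 3), that is, output shape followed by input shape. The vectorized path is the one that requires every op in the backward pass to accept batched tensors.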

My questions are:

  1. Is there a plan to officially support torch.func? If I can get some guidance from the pytorch3d team, I am happy to collaborate on this.
  2. Any idea how to make this work? Any workarounds?

Thanks!

bottler commented 1 month ago

We aren't planning torch.func support. It seems to me that the method in #1533 should work fine for _PointFaceDistance - feel free to post the code you have and maybe we can figure out what's wrong.

xhsonny commented 1 month ago

@bottler Thanks for the reply!

The error message "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" from #1533 happened in the forward pass, at "idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)". I don't know how #1533 computes the Jacobian. In my case, neither forward mode nor backward mode works; they fail with the following errors:

I did follow #1533 and modified the class to use setup_context etc. I also added a vmap method, but it is never called before the two errors above are hit.
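For readers unfamiliar with the pattern mentioned here, the following is a minimal sketch of a torch.autograd.Function written in the setup_context style with a hand-written vmap rule. NumpySquare is a hypothetical toy op, not pytorch3d code; its forward goes through NumPy to imitate a kernel that cannot see batched tensors.

```python
import torch
from torch.autograd import Function


class NumpySquare(Function):
    """Hypothetical toy op whose forward runs outside of PyTorch (via NumPy)."""

    @staticmethod
    def forward(x):
        # forward takes no ctx when setup_context is defined separately.
        return torch.from_numpy(x.detach().numpy() ** 2)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Plain PyTorch ops, so this backward needs no extra rule of its own.
        return 2 * x * grad_out

    @staticmethod
    def vmap(info, in_dims, x):
        # in_dims holds one entry per forward argument: the batch dim or None.
        (x_bdim,) = in_dims
        x = x.movedim(x_bdim, 0)
        # The op is elementwise, so the batch dim can simply ride along.
        out = NumpySquare.apply(x)
        # Return the output together with the position of its batch dim.
        return out, 0


numpy_square = NumpySquare.apply

x = torch.randn(5, 3)
batched = torch.vmap(numpy_square)(x)
print(batched.shape)
```

The vmap staticmethod receives tensors with the batch dimension exposed as a real dimension, so an opaque kernel can be called on them directly.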

xhsonny commented 1 month ago

@bottler I wrote a toy example that follows this pytorch3d tutorial with the following modifications:

  1. Only use _PointFaceDistance in the objective. Because I only care whether we can compute the Jacobian, it does not matter if the optimization actually runs.
  2. Added vmap to _PointFaceDistance and added setup_context.
  3. Used theseus

Here is the code. It is self-contained and will download dolphin.obj following the pytorch3d tutorial. Sorry that the code is a bit long; it has to include the _PointFaceDistance updates.

import os
import urllib.request

import einops
import theseus as th
import torch
from pytorch3d import _C
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from torch.autograd import Function
from torch.autograd.function import once_differentiable

_DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3

# PointFaceDistance
class _PointFaceDistance(Function):
    """
    Torch autograd Function wrapper PointFaceDistance Cuda implementation
    """

    generate_vmap_rule = False

    @staticmethod
    def forward(
        # ctx,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        """
        Args:
            ctx: Context object used to calculate gradients.
            points: FloatTensor of shape `(P, 3)`
            points_first_idx: LongTensor of shape `(N,)` indicating the first point
                index in each example in the batch
            tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th
                triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])`
            tris_first_idx: LongTensor of shape `(N,)` indicating the first face
                index in each example in the batch
            max_points: Scalar equal to maximum number of points in the batch
            min_triangle_area: (float, defaulted) Triangles of area less than this
                will be treated as points/lines.
        Returns:
            dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
                euclidean distance of `p`-th point to the closest triangular face
                in the corresponding example in the batch
            idxs: LongTensor of shape `(P,)` indicating the closest triangular face
                in the corresponding example in the batch.

            `dists[p]` is
            `d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])`
            where `d(u, v0, v1, v2)` is the distance of point `u` from the triangular
            face `(v0, v1, v2)`

        """
        dists, idxs = _C.point_face_dist_forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        # ctx.save_for_backward(points, tris, idxs)
        # ctx.min_triangle_area = min_triangle_area
        return dists, idxs

    @staticmethod
    def setup_context(ctx, inputs, output):
        (
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        ) = inputs
        dists, idxs = output
        ctx.save_for_backward(points, tris, idxs)
        ctx.min_triangle_area = min_triangle_area
        ctx.dists = dists
        ctx.idxs = idxs
        ctx.inputs = inputs

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idxs):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area
        grad_points, grad_tris = _C.point_face_dist_backward(
            points, tris, idxs, grad_dists, min_triangle_area
        )
        return grad_points, None, grad_tris, None, None, None

    @staticmethod
    def vmap(
        info,
        in_dims,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        # in_dims has one entry per forward argument: its batch dim, or None.
        (
            points_bdim,
            points_first_idx_bdim,
            tris_bdim,
            tris_first_idx_bdim,
            _,
            _,
        ) = in_dims

        points_V, points_P, points_C = points.shape
        points = einops.rearrange(points, "V P C -> (V P) C")

        tris = einops.rearrange(tris, "V T A B -> (V T) A B")

        dists, idx = _PointFaceDistance.forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        dists = einops.rearrange(dists, "(V P) -> V P", V=points_V)
        idx = einops.rearrange(idx, "(V P) -> V P", V=points_V)
        return (dists, idx), (0, 0)

point_face_distance = _PointFaceDistance.apply

def point_to_mesh_distance(points, mesh_v, mesh_f):

    scale_fac = 100.0  # see explanation above

    # packed representation for pointclouds
    points = points * scale_fac  # (P, 3)
    points_first_idx = torch.zeros([1])
    max_points = points.shape[0]

    # packed representation for faces
    verts_packed = mesh_v * scale_fac
    faces_packed = mesh_f
    tris = verts_packed[faces_packed.to(torch.int)]
    tris_first_idx = torch.zeros([1])

    point_to_face, _ = point_face_distance(
        points.to(torch.float32),
        points_first_idx.to(torch.long),
        tris.to(torch.float32),
        tris_first_idx.to(torch.long),
        max_points,
    )
    point_to_face = point_to_face / (scale_fac**2)
    return torch.sqrt(point_to_face)

device = "cpu"
target_obj_path = "dolphin.obj"
if not os.path.exists(target_obj_path):
    # Reference: https://pytorch3d.org/tutorials/deform_source_mesh_to_target_mesh
    src_url = (
        "https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj"
    )
    print(f"Downloading from {src_url}")
    urllib.request.urlretrieve(
        src_url,
        "dolphin.obj",
    )
verts, faces, aux = load_obj(target_obj_path)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale

target_mesh = Meshes(verts=[verts], faces=[faces_idx])
src_mesh = ico_sphere(4, device)

deform_verts = th.Vector(
    tensor=src_mesh.verts_packed().reshape(1, -1),
    name="deform_v",
)

target_v = th.Variable(verts.reshape(1, -1), name="target_v")
faces_idx = faces_idx.to(torch.float32)
target_f = th.Variable(faces_idx.reshape(1, -1), name="target_f")

def error_fn(optim_vars, aux_vars):
    (verts,) = optim_vars
    target_v, target_f = aux_vars
    p2m = point_to_mesh_distance(
        verts.tensor.reshape(-1, 3).to(torch.float32),
        mesh_v=target_v.tensor.reshape(-1, 3).to(torch.float32),
        mesh_f=target_f.tensor.reshape(-1, 3),
    ).to(torch.float64)
    return p2m.unsqueeze(0)

optim_vars = (deform_verts,)
aux_vars = target_v, target_f
cost_function = th.AutoDiffCostFunction(
    optim_vars,
    error_fn,
    deform_verts.shape[1] // 3,  # error dimension must be an int
    aux_vars=aux_vars,
    name="l2",
)

# grad_points, grad_tris = _C.point_face_dist_backward(
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
cost_function.jacobians()

When cost_function.jacobians() is called, it throws an exception. Full error below:

Traceback (most recent call last):
  File "/Users/sonny/jac_theseus.py", line 227, in <module>
    cost_function.jacobians()
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 355, in jacobians
    jacobians_full = self._compute_autograd_jacobian_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
    return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 609, in wrapper_fn
    flat_jacobians_per_input = compute_jacobian_stacked()
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 540, in compute_jacobian_stacked
    chunked_result = vmap(vjp_fn)(basis)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 336, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 124, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/autograd_function.py", line 123, in backward
    result = autograd_function.backward(ctx, *grads)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 570, in wrapper
    outputs = fn(ctx, *args)
  File "/Users/sonny/jac_theseus.py", line 96, in backward
    grad_points, grad_tris = _C.point_face_dist_backward(
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

I don't know how to start debugging this as _C.point_face_dist_backward happens in CUDA/CPU code. If you have any pointers, please let me know. Thanks a lot!

TimoRST commented 1 month ago

Just an intuition based on the traceback, especially here:

File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
    return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)

I think that jac_fn will be called in the same way that the forward path is called with vmap. So, there might be an error in the jac_fn function's vmapping, which you might specify in the same way as you did for the forward path.

xhsonny commented 1 month ago

@TimoRST Thanks for the inputs! I never thought of it and will look into it.

Just one follow-up: jac_fn is either the cost_function, which is a th.AutoDiffCostFunction, or the error_fn I wrote to compute the loss values. Do you know how to add vmap to those functions? They are not torch.autograd.Function subclasses.

Thanks!

TimoRST commented 1 month ago

I don't know that. I would first debug into that function to see if it really is called with those batched tensors which cause the error. After verifying you might just make an autodiff function out of that function?

xhsonny commented 1 month ago

@TimoRST Thanks. I will try it out.

Can I ask what your use case was for knn and theseus? Was it also in the context of an optimization that needed a full Jacobian?

Thanks!

TimoRST commented 1 month ago

I wanted to implement something like WGICP (https://arxiv.org/abs/2209.09777), but I couldn't scale it because my graphics card was too small, so I didn't get comparable results. I didn't use the Jacobian, so it was enough to ensure correct vmapping in the forward path.

xhsonny commented 1 month ago

@TimoRST Thanks for the details. Really appreciate the help!

@bottler could you take a look at the code I posted above? I am going to follow @TimoRST's suggestion and take a look at the error function. Meanwhile, if you find anything in my code, please let me know. Thanks!

xhsonny commented 1 month ago

Here are some findings. Conclusion: _C.point_face_dist_backward cannot take a grad_dists that is a 2D BatchedTensor, which is what vmap/jacrev creates.

Here is a hacked version to verify my point, though the math is probably wrong. When using jacrev, it "un-vmaps" grad_dists into v_grad_dists and uses a for loop to compute the result row by row. That part works fine. But when it returns, pytorch3d complains that the returned grad_points has the wrong shape.

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idxs):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area

        # https://discuss.pytorch.org/t/save-batchedtensor-to-a-pickle-file/170561/4
        v_points = torch._C._functorch.get_unwrapped(points)
        v_tris = torch._C._functorch.get_unwrapped(tris)
        v_idxs = torch._C._functorch.get_unwrapped(idxs)
        v_grad_dists = torch._C._functorch.get_unwrapped(grad_dists)

        grad_points = []
        grad_tris = None
        for v_grad_dists_v in v_grad_dists:
            v_grad_points, v_grad_tris = _C.point_face_dist_backward(
                v_points, v_tris, v_idxs, v_grad_dists_v, min_triangle_area
            )
            grad_points.append(v_grad_points)
            if grad_tris is not None:
                grad_tris = grad_tris + v_grad_tris
            else:
                grad_tris = v_grad_tris

        grad_points = torch.cat(grad_points, dim=1)
        return grad_points, None, grad_tris, None, None, None

@bottler would you be able to confirm that what I said above is correct? If so, it seems that changing the internal code of _C.point_face_dist_backward is the only option?
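One pattern that stays in Python, sketched here under assumptions, is to wrap the opaque backward kernel in a second autograd.Function that has its own vmap rule, and to call that Function from backward. The per-row loop then runs inside the vmap rule on plain tensors. Everything below is a hypothetical toy: _opaque_backward_kernel stands in for a compiled kernel such as _C.point_face_dist_backward, SqNorm stands in for the distance op, and only the non-nested jacrev case (no outer vmap, as Theseus adds) is exercised.

```python
import torch
from torch.autograd import Function


def _opaque_backward_kernel(points, grad_dists):
    # Hypothetical stand-in for a compiled kernel. It goes through NumPy, so it
    # needs plain tensors: points is (P, 3) and grad_dists is (P,).
    p = points.detach().numpy()
    g = grad_dists.detach().numpy()
    return torch.from_numpy(2.0 * p * g[:, None])


class _SqNormBackward(Function):
    """Wraps the opaque kernel so that it can carry its own vmap rule."""

    @staticmethod
    def forward(points, grad_dists):
        return _opaque_backward_kernel(points, grad_dists)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Nothing is saved: this sketch does not support double backward.
        pass

    @staticmethod
    def vmap(info, in_dims, points, grad_dists):
        points_bdim, grad_bdim = in_dims
        # Give both inputs a leading batch dimension of size info.batch_size.
        if points_bdim is None:
            points = points.unsqueeze(0).expand(info.batch_size, *points.shape)
        else:
            points = points.movedim(points_bdim, 0)
        if grad_bdim is None:
            grad_dists = grad_dists.unsqueeze(0).expand(
                info.batch_size, *grad_dists.shape
            )
        else:
            grad_dists = grad_dists.movedim(grad_bdim, 0)
        # Call the kernel once per batch entry, always with a 1D gradient.
        rows = [_opaque_backward_kernel(p, g) for p, g in zip(points, grad_dists)]
        return torch.stack(rows), 0


class SqNorm(Function):
    """Hypothetical toy op: squared norm of each point, computed via NumPy."""

    @staticmethod
    def forward(points):
        p = points.detach().numpy()
        return torch.from_numpy((p * p).sum(-1))

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)

    @staticmethod
    def backward(ctx, grad_dists):
        (points,) = ctx.saved_tensors
        # Delegate to the wrapped kernel so that vmap can dispatch to its rule.
        return _SqNormBackward.apply(points, grad_dists)


sq_norm = SqNorm.apply

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
jac = torch.func.jacrev(sq_norm)(x)
ref = torch.func.jacrev(lambda p: (p * p).sum(-1))(x)
print(jac.shape, torch.allclose(jac, ref))
```

The loop inside the vmap rule costs one kernel call per Jacobian row, so it is not expected to be faster than vectorize=False; its purpose is only to make the op usable under jacrev. Whether the same wrapping works for _PointFaceDistance under Theseus's nested vmap(jacrev(...)) is not verified here.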

Thanks!