pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Get .item() error without calling .item() #1024

Open: LiXinrong1012 opened this issue 1 year ago

LiXinrong1012 commented 1 year ago

Hello guys, I'm new to this package and I want to compute a batched Jacobian of a vector function I implemented myself. However, I get the following error when I do this:

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
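For context, this is vmap's generic guard against `.item()`. Inside vmap the function is written for one sample, but the tensor it receives actually carries the whole batch, so there is no single Python number that `.item()` could return. The sketch below shows what triggers the error and the kind of tensor-only rewrite the message asks for. It assumes PyTorch 2.0 or later, where functorch's transforms are available as `torch.func`; `clip_by_first` and `clip_by_first_tensor` are hypothetical functions, not taken from this issue.

```python
import torch
from torch.func import vmap  # PyTorch >= 2.0; older installs: from functorch import vmap

def clip_by_first(x):
    # .item() turns a tensor into a Python number; under vmap the value
    # differs per batch element, so vmap refuses to do this.
    threshold = x[0].item()
    return torch.clamp(x, max=threshold)

def clip_by_first_tensor(x):
    # Same computation kept in tensor form (no .item()), so vmap can batch it.
    threshold = x[0]
    return torch.minimum(x, threshold)

batch = torch.randn(4, 11)

raised = False
try:
    vmap(clip_by_first)(batch)
except RuntimeError as err:
    raised = True
    print("vmap rejected .item():", err)

out = vmap(clip_by_first_tensor)(batch)
print(out.shape)
```

The rewrite keeps the threshold as a 0-dim tensor and relies on broadcasting in `torch.minimum`, which is the usual way to remove `.item()` calls from code that needs to run under vmap.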

Here is my code. I don't understand where the .item() comes from. Is this slicing operation q_current[0:3] wrong? How can I fix this?

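```python
import torch
from functorch import jacrev, vmap

# Input batch: 4 samples, each a vector of length 4*3-1 = 11
q_current = torch.randn((4, 4 * 3 - 1), requires_grad=True)

def geoCompute(q_current):
    # Written for a single (unbatched) sample; vmap adds the batch dimension
    k1 = q_current[0:3]  # first three entries
    return k1

# Per-sample Jacobian of geoCompute, batched over dim 0 -> shape (4, 3, 11)
jacobian = vmap(jacrev(geoCompute))(q_current)
```
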
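For reference, `geoCompute` only selects the first three entries, so the Jacobian of each sample is the 3x11 selection matrix [I_3 | 0]. The sketch below cross-checks `vmap(jacrev(...))` against a plain per-sample loop using `torch.autograd.functional.jacobian`. It assumes PyTorch 2.0 or later and uses `torch.func`, the successor to the standalone functorch package; on older installs the same `jacrev`/`vmap` come from `functorch`.

```python
import torch
from torch.func import jacrev, vmap  # PyTorch >= 2.0; older installs: from functorch import ...

def geoCompute(q_current):
    # Select the first three entries of a single (unbatched) 11-vector.
    return q_current[0:3]

q_current = torch.randn(4, 11)  # batch of 4 vectors of length 4*3-1 = 11

# vmap maps jacrev over the leading batch dimension: one 3x11 Jacobian per row.
batched_jac = vmap(jacrev(geoCompute))(q_current)

# Reference: compute each Jacobian separately with the classic autograd API.
looped_jac = torch.stack(
    [torch.autograd.functional.jacobian(geoCompute, row) for row in q_current]
)

# Slicing is linear, so every Jacobian is the same selection matrix [I_3 | 0].
expected = torch.eye(3, 11).expand(4, 3, 11)

print(batched_jac.shape)  # torch.Size([4, 3, 11])
```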
AlphaBetaGamma96 commented 1 year ago

What version of functorch are you using? Your minimal reproducible script works for me. It returns:

```
tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]]])
```
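Since the answer hinges on which version is installed, a small script that collects the relevant versions can help when replying. This is a hypothetical helper, not part of functorch; PyTorch's own `python -m torch.utils.collect_env` gives a fuller report. It assumes that a standalone functorch install registers a `functorch` distribution, while from PyTorch 1.13 on functorch ships inside the torch wheel and so has no separate package metadata.

```python
import importlib.metadata
import platform

import torch

def report_versions():
    """Collect the version info a maintainer typically asks for in a bug report."""
    info = {
        "python": platform.python_version(),
        "torch": torch.__version__,
    }
    try:
        # A standalone functorch install registers its own distribution.
        info["functorch"] = importlib.metadata.version("functorch")
    except importlib.metadata.PackageNotFoundError:
        # From PyTorch 1.13 on, functorch is bundled inside the torch wheel.
        info["functorch"] = "not a separate package (bundled with torch or not installed)"
    return info

for name, version in report_versions().items():
    print(f"{name}: {version}")
```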