Is there a way of vmapping over indexing a Tensor with per-batch indices? Minimal reproducible example below:
import torch
from functorch import vmap
def select(x, index):
    print(x.shape, index.shape)
    return x[index]

x = torch.randn(64, 1000)  # 64 vectors of length 1000
index = torch.arange(64)  # index for each vector
out = vmap(select, in_dims=(0, 0))(x, index)  # vmap over the process
print(out)  # should output a vector of length 64
This should take in a batch of vectors and select the corresponding entry from the index vector (which can be viewed as a batch of scalars, and is hence represented as a vector).
The error is as follows:
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
I tried using torch.select, but that requires passing the index as an int rather than a Tensor, so it must call .item() internally. Is there a workaround that already exists for this?
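For reference, one way to get the same per-row selection without going through vmap at all is batched advanced indexing or torch.gather. This is a sketch of that alternative, not a vmap-native fix, assuming the goal is simply "pick element index[i] from row i":

```python
import torch

x = torch.randn(64, 1000)  # 64 vectors of length 1000
index = torch.arange(64)   # one index per vector

# Option 1: advanced (fancy) indexing — pair each row number with its index
out1 = x[torch.arange(x.shape[0]), index]

# Option 2: torch.gather along dim 1, then drop the singleton dimension
out2 = x.gather(1, index.unsqueeze(1)).squeeze(1)

assert out1.shape == (64,)
assert torch.equal(out1, out2)
```

Both variants produce a length-64 vector where entry i is x[i, index[i]], which matches what the vmapped select was meant to return.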