pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap equivalent for tensor[indices] #1074

Closed AlphaBetaGamma96 closed 1 year ago

AlphaBetaGamma96 commented 1 year ago

Hi,

Is there a way of vmapping over indexing into a Tensor with a batch of indices? Minimal reproducible example below:

import torch
from functorch import vmap

def select(x, index):
  print(x.shape, index.shape)
  return x[index]

x = torch.randn(64, 1000)  # 64 vectors of length 1000
index = torch.arange(64)   # index for each vector

out = vmap(select, in_dims=(0, 0))(x, index)  # vmap over the process
print(out)  # should output a vector of 64 selected values

This should take in a batch of vectors and select the corresponding element from each one using the index vector (which can be viewed as a batch of scalars, and is hence represented as a vector).

The error is as follows:

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

I tried using torch.select, but that requires passing the index as an int rather than a Tensor, so it must call .item() internally. Is there a workaround that already exists for this?
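
To make what I'm after concrete: without vmap, the same result can be written with plain advanced indexing (a small sketch, assuming the same x and index as above):

import torch

x = torch.randn(64, 1000)  # 64 vectors of length 1000
index = torch.arange(64)   # one index per vector

# the desired output: element index[i] taken from row i, for every row
expected = x[torch.arange(64), index]
print(expected.shape)  # torch.Size([64])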

AlphaBetaGamma96 commented 1 year ago

Just figured out what you can do:

def gather(x, index):
  # torch.gather takes the index as a Tensor, so vmap never has to call .item()
  return torch.gather(x, 0, index)

out = vmap(gather, in_dims=(0, 0))(x, index)
print(out)  # returns a vector of 64 values
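
As a quick sanity check (a sketch, assuming the same x and index as in the first post), the vmapped gather matches plain advanced indexing:

import torch
from functorch import vmap

x = torch.randn(64, 1000)  # 64 vectors of length 1000
index = torch.arange(64)   # one index per vector

out = vmap(lambda v, i: torch.gather(v, 0, i), in_dims=(0, 0))(x, index)

# each output element should equal x[row, index[row]]
print(torch.equal(out, x[torch.arange(64), index]))  # True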