pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Can I call torch.utils.data.WeightedRandomSampler inside vmap? #1123

Closed kwmaeng91 closed 1 year ago

kwmaeng91 commented 1 year ago

Dear Experts,

I am trying to accelerate a series of weighted sampling steps (i.e., transitions using a stochastic matrix) with vmap. Specifically, I want to speed up the code from https://discuss.pytorch.org/t/best-way-to-implement-series-of-weighted-random-sampling-for-transition-w-stochastic-matrix/176713 by replacing the for loop with vmap and calling torch.utils.data.WeightedRandomSampler() inside it (the link is my question in the general forum asking for alternative ways to accelerate this). However, I get an error, and I am not sure whether this is possible.

Below is my code:

import torch
from torch import nn
from functorch import vmap

N = 10
M = 20
L = 5

P = torch.rand([N, M])
x = torch.randint(0, N, [L])
P_new = torch.stack([P[x[i]] for i in range(L)])

f = lambda p: torch.tensor(list(torch.utils.data.WeightedRandomSampler(p, 1))[0])
y = vmap(f, randomness='different')(P_new)

print(y)

Ideally, I want to sample L elements, where the i-th element is drawn from the distribution P[x[i]] for i in range(L). Below is the error I get:

Traceback (most recent call last):
  File "xxx/test.py", line 17, in <module>
    y = vmap(f, randomness='different')(P_new)
  File "xxx/functorch/_src/vmap.py", line 361, in wrapped
    return _flat_vmap(
  File "xxx/functorch/_src/vmap.py", line 487, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "xxx/test.py", line 16, in <lambda>
    f = lambda p: torch.tensor(list(torch.utils.data.WeightedRandomSampler(p, 1))[0])
  File "xxxx/site-packages/torch/utils/data/sampler.py", line 203, in __iter__
    yield from iter(rand_tensor.tolist())
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

I wonder whether something like this is fundamentally impossible, or whether there is a way around the error.

Any help would be highly appreciated! Thank you

kwmaeng91 commented 1 year ago

Actually, I figured out that torch.multinomial() does exactly what I want, so I no longer need the complicated combination of vmap + WeightedRandomSampler to work.
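For anyone landing here with the same problem, below is a minimal sketch of how the torch.multinomial() approach might look. It assumes the same N, M, L shapes as the snippet above; the seed and the use of advanced indexing (P[x]) in place of the torch.stack loop are my own choices for illustration, not something the thread prescribes.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only so the sketch is repeatable

N, M, L = 10, 20, 5
P = torch.rand(N, M)            # non-negative weights, one row per distribution
x = torch.randint(0, N, (L,))   # which row of P each of the L samples should use

# Advanced indexing gathers the L rows at once; it yields the same tensor as
# torch.stack([P[x[i]] for i in range(L)]).
P_new = P[x]                    # shape (L, M)

# torch.multinomial treats every row as its own weight vector and draws
# num_samples indices per row. Rows do not have to sum to 1.
y = torch.multinomial(P_new, num_samples=1).squeeze(1)  # shape (L,)

print(y)
```

Because torch.multinomial() already operates row-wise on a 2-D input, the batching that vmap was meant to provide comes for free, and no Python-level iteration over samplers is needed.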

Still, I would be interested to learn what the above error is about.