kwmaeng91 closed this issue 1 year ago.
Actually, I found that torch.multinomial() does exactly what I want, so I no longer need the complicated combination of vmap + WeightedRandomSampler to work.
Still, I would be interested to learn what the error in my original question is about.
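A minimal sketch of the torch.multinomial() approach mentioned above, assuming P is an N×N row-stochastic matrix (row r is the distribution of the next state given state r) and x is a 1-D int64 tensor of L current states. The function name step_multinomial is hypothetical, not code from this thread. The key point is that torch.multinomial treats each row of a 2-D input as a separate distribution, so gathering the rows with P[x] replaces the whole loop with a single call.

```python
import torch

def step_multinomial(P: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """One transition step for every element of x at once.

    P: (N, N) row-stochastic matrix, row r = distribution of the next state given state r.
    x: (L,) int64 tensor of current states.
    """
    # P[x] gathers one row per element -> (L, N); torch.multinomial treats
    # each row of a 2-D input as its own distribution and returns (L, 1).
    return torch.multinomial(P[x], num_samples=1).squeeze(1)

torch.manual_seed(0)
N, L = 4, 6
P = torch.rand(N, N)
P = P / P.sum(dim=1, keepdim=True)  # make each row sum to 1
x = torch.randint(0, N, (L,))
y = step_multinomial(P, x)
print(y.shape)  # torch.Size([6])
```

torch.multinomial also accepts unnormalized non-negative weights, so the row normalization is only there to make P a proper stochastic matrix.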
Dear Experts,
I am trying to use vmap to accelerate a series of weighted random samplings (i.e., transitions using a stochastic matrix). Specifically, I want to speed up the code from https://discuss.pytorch.org/t/best-way-to-implement-series-of-weighted-random-sampling-for-transition-w-stochastic-matrix/176713 (that link is my question on the general forum asking for alternative ways to accelerate it) by replacing the for loop with vmap and calling torch.utils.data.WeightedRandomSampler() inside it. However, I get an error, and I am not sure whether this is possible at all.
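For reference, the loop-based approach that vmap is meant to replace might look roughly like the sketch below. This is a hypothetical reconstruction, not the code from the linked thread: it assumes P is an N×N row-stochastic matrix and x a 1-D int64 tensor of current states, and the name step_loop is made up for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

def step_loop(P: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop version: one Python iteration per element of x."""
    out = torch.empty_like(x)
    for i in range(x.numel()):
        # Row P[x[i]] is used as the sampling weights for element i.
        sampler = WeightedRandomSampler(P[x[i]], num_samples=1, replacement=True)
        out[i] = next(iter(sampler))  # the sampler yields plain Python ints
    return out

torch.manual_seed(0)
N, L = 4, 6
P = torch.rand(N, N)
P = P / P.sum(dim=1, keepdim=True)  # make each row sum to 1
x = torch.randint(0, N, (L,))
y_loop = step_loop(P, x)
```

Each iteration builds a new sampler and draws a single index, which is why this becomes slow for large L.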
Below is my code:
Ideally, I want to sample L elements, drawing the i-th one from the distribution P[x[i]], for i in range(L). Below is the error I get:
I wonder whether something like this is fundamentally impossible, or whether there is a way around the error.
Any help would be highly appreciated! Thank you.
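The original code and error message are not reproduced in this thread, so the following is only a guess at the cause. torch.func.vmap (PyTorch 2.0+) defaults to randomness="error", which rejects any random operation inside the vmapped function, and WeightedRandomSampler both draws with torch.multinomial and converts the result to Python ints with .tolist() while iterating, a data-dependent step that vmap generally cannot batch. A minimal sketch that keeps vmap but calls torch.multinomial directly, assuming the same P and x as above (sample_row and step_vmap are hypothetical names):

```python
import torch
from torch.func import vmap

def sample_row(weights: torch.Tensor) -> torch.Tensor:
    # Inside vmap this sees a single (N,) row and draws one index from it.
    return torch.multinomial(weights, num_samples=1)

def step_vmap(P: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    rows = P[x]  # gather the (L, N) weight rows outside vmap
    # randomness="different" gives each row its own independent draw; the
    # default randomness="error" raises as soon as a random op is called.
    batched = vmap(sample_row, randomness="different")
    return batched(rows).squeeze(1)  # (L, 1) -> (L,)

torch.manual_seed(0)
N, L = 4, 6
P = torch.rand(N, N)
P = P / P.sum(dim=1, keepdim=True)  # make each row sum to 1
x = torch.randint(0, N, (L,))
y_vmap = step_vmap(P, x)
```

In practice this is no better than calling torch.multinomial(P[x], 1) directly, since multinomial already handles a batch of rows; vmap only becomes useful when the per-row function does more than a single sampling call.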