Closed: AlphaBetaGamma96 closed this issue 2 years ago.
The short answer is that we currently don't allow this use case :P. Vmap expects that your functions are pure, and so it only properly handles values that are returned from the function.
The long answer is that we don't support this use case, but we're quite interested in doing so in the future.
The fundamental reason this is difficult is that allowing this kind of use case breaks the lexical scoping guarantees that make vmap (relatively) easy to implement in the first place. In this case, what's happening is that we're "leaking" the internal BatchedTensors out of the vmapped function, and the semantics for this are not (currently) clearly defined.
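To make the "pure function" point concrete, here is a minimal sketch of the supported pattern (using torch.func.vmap from recent PyTorch; the older functorch.vmap behaves the same way): return the intermediate from the function, and vmap will un-batch it along with the output.

```python
import torch
from torch.func import vmap  # older releases: from functorch import vmap

def f(x):
    hidden = x * 2       # an intermediate value we want to capture
    out = hidden.sum()
    # Returning the intermediate lets vmap handle it properly,
    # instead of leaking a BatchedTensor out through a closure.
    return out, hidden

out, hidden = vmap(f)(torch.randn(3, 5))
print(out.shape)     # torch.Size([3])
print(hidden.shape)  # torch.Size([3, 5])
```

Because the intermediate is part of the function's return value, it comes back as an ordinary tensor with the batch dimension restored.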
where the BatchedTensor is clearly of shape [nsamples, ninput] but its shape is labelled as [ninput]
This is intended and is, in fact, how vmap works. Basically, vmap turns tensors into "batched tensors", which look like they have size [N] but secretly are of size [B, N] (and broadcast all operations across this additional batch dimension).
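That illusion is easy to observe directly. A small sketch (assuming torch.func.vmap; the shapes are illustrative):

```python
import torch
from torch.func import vmap  # older releases: from functorch import vmap

shapes_seen = []

def f(x):
    # Inside vmap, x is a BatchedTensor: it reports shape [5],
    # but secretly carries a batch dimension of size 3.
    shapes_seen.append(tuple(x.shape))
    return x * 2

out = vmap(f)(torch.randn(3, 5))
print(shapes_seen[0])  # (5,)
print(out.shape)       # torch.Size([3, 5])
```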
Finally, to answer this question
Is there a way to grab all the intermediate values?
Yes, but 1. it's a private API, and 2. we don't promise not to break your code in the future if you use this API :)
import torch
from functorch import vmap
from functorch._C import _remove_batch_dim

val = []

def f(x):
    val.append(x)  # leak the BatchedTensor out of the vmapped function
    return x

vmap(f)(torch.randn(3, 5))

print(val[0])  # a BatchedTensor whose reported shape is [5]

# Private API: strip the batch dimension to recover the full [3, 5] tensor
vmap_level = 2
batch_size = 3
out_dim = 0
print(_remove_batch_dim(val[0], vmap_level, batch_size, out_dim))
Thank you for the prompt reply!
Yes, but 1. it's a private API, and 2. we don't promise not to break your code in the future if you use this API :)
Thank you for sharing a solution to this problem! I was wondering if I could follow up to understand what exactly vmap_level, batch_size and out_dim correspond to? I assume batch_size is the size of the batch (obviously), and out_dim is the dim of the batch? And for vmap_level, I assume this is the level of the wrapped BatchedTensors? Is there a way to get this in case it changes on a per-case basis? I did notice in the private API there's a functorch._C.maybe_get_level; could this be a solution to that problem?
Thank you once again! :)
I was wondering if I could follow up to just understand what exactly vmap_level, batch_size and out_dim correspond to?
vmap_level is the level of the wrapped BatchedTensors, and is basically an internal detail that allows vmap to compose with itself and grad.
batch_size is another implementation detail that we might actually be able to remove now...
out_dim is essentially just an argument that says where the extra "vmap batch dim" should live after vmap.
there's a functorch._C.maybe_get_level; could this be a solution to that problem?
I think that should work.
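For reference, a sketch of that idea: query the level from the leaked BatchedTensor instead of hardcoding it. This leans entirely on private APIs that have moved around between releases (in recent PyTorch the same helpers live under torch._C._functorch), so treat every name here as an implementation detail that may change or disappear.

```python
import torch
from torch.func import vmap  # older releases: from functorch import vmap

try:
    # older functorch packaging
    from functorch._C import maybe_get_level, _remove_batch_dim
except ImportError:
    # where the same private helpers live in recent PyTorch
    from torch._C._functorch import maybe_get_level, _remove_batch_dim

val = []

def f(x):
    val.append(x)  # leak the BatchedTensor
    return x

vmap(f)(torch.randn(3, 5))

bt = val[0]
level = maybe_get_level(bt)  # avoids hardcoding vmap_level
full = _remove_batch_dim(bt, level, 3, 0)
print(full.shape)  # torch.Size([3, 5])
```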
Hello, I've been trying to cache the intermediate values for nn.Linear layers, but I can't access the intermediate values for all samples. An example script to reproduce this undesired behaviour is below. Ideally, I'd want the inputs for the first layer fc1 to be of shape [nsamples, ninput], however it's just [ninput]. An example output from this script is as follows, where the BatchedTensor is clearly of shape [nsamples, ninput] but its shape is labelled as [ninput]. Is there a way to grab all the intermediate values? Perhaps convert a BatchedTensor to a torch.Tensor object? If I try the following, it returns this error,
Thank you!
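For context, here is a hypothetical reconstruction of the kind of script described above (the actual script isn't shown; the fc1 layer and the sizes here are illustrative): caching a layer's input with a forward pre-hook while the model runs under vmap only stashes a BatchedTensor whose reported shape omits the sample dimension.

```python
import torch
import torch.nn as nn
from torch.func import vmap  # older releases: from functorch import vmap

nsamples, ninput, noutput = 3, 4, 2
fc1 = nn.Linear(ninput, noutput)

cached = []

def pre_hook(module, inputs):
    cached.append(inputs[0])  # stash the layer's input

fc1.register_forward_pre_hook(pre_hook)

out = vmap(fc1)(torch.randn(nsamples, ninput))
print(cached[0].shape)  # reports torch.Size([4]), not [3, 4]
print(out.shape)        # torch.Size([3, 2])
```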