pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

:bug: register_forward_pre_hook returns incorrect values with vmap #488

Closed: AlphaBetaGamma96 closed this issue 2 years ago

AlphaBetaGamma96 commented 2 years ago

Hello, I've been trying to cache the intermediate values of nn.Linear layers via hooks, but I can't access the values for all samples when using vmap. An example script to reproduce this undesired behaviour is below:

import torch
from torch import nn
from functorch import make_functional, vmap

class network(nn.Module):

  def __init__(self, ninput, nhidden, noutput):
    super(network, self).__init__()

    self.fc1 = nn.Linear(ninput, nhidden)
    self.fc2 = nn.Linear(nhidden, noutput)

    self.af = nn.Tanh()

  def forward(self, x):
    x = self.fc1(x)
    x = self.af(x)
    x = self.fc2(x)
    return x

nsamples=4096
ninput=4
nhidden=32
noutput=1

model = network(ninput=ninput,
                nhidden=nhidden,
                noutput=noutput)

# forward pre-hook and full backward hook
hooks = {}
def _save_input(module, input) -> None:
  hooks['a'] = input[0]  # cache the layer's input

def _save_output(module, grad_input, grad_output) -> None:
  hooks['e'] = grad_output[0]  # cache the gradient w.r.t. the layer's output

# register both hooks on the first layer, fc1
model.fc1.register_forward_pre_hook(_save_input)
model.fc1.register_full_backward_hook(_save_output)

x = torch.randn(nsamples, ninput)  # input data

func_model, params = make_functional(model)

output = vmap(func_model, (None, 0))(params, x)

print("activations: ",hooks['a'])
print("activations shape: ",hooks['a'].shape)

Ideally, I'd want the input to the first layer fc1 to be of shape [nsamples, ninput]; however, it's just [ninput]. An example output from this script is as follows:

activations:  BatchedTensor(lvl=2, bdim=0, value=
    tensor([[ 0.0341,  1.7255, -0.4531,  0.7700],
            [-1.2910,  0.5553,  0.2371, -0.3498],
            [ 2.0675, -0.4086,  1.6362, -0.3628],
            ...,
            [ 0.2447,  1.0775,  0.2377, -0.4810],
            [-0.7522, -0.9259,  1.1660, -1.4304],
            [-0.6454, -0.4539, -1.0739, -0.0787]])
)
activations shape:  torch.Size([4])

where the BatchedTensor is clearly of shape [nsamples, ninput], but its shape is labelled as [ninput]. Is there a way to grab all the intermediate values? Perhaps by converting a BatchedTensor to a torch.Tensor object? If I try the following,

print("activations shape: ",torch.Tensor(hooks['a']).shape)

I get this error:

RuntimeError: batched == nullptrINTERNAL ASSERT FAILED at "/tmp/pip-req-build-4bwnfz__/functorch/csrc/DynamicLayer.cpp":367, please report a bug to PyTorch. 

Thank you!

Chillee commented 2 years ago

The short answer is that we currently don't allow this use case :P. Vmap expects that your functions are pure, and so it only properly handles values that are returned from the function.

The long answer is that we don't support this use case, but we're quite interested in doing so in the future.

The fundamental reason this is difficult is that allowing this kind of use case breaks the lexical scoping guarantees that make vmap (relatively) easier to implement in the first place. In this case, what's happening is that we're "leaking" the internal BatchedTensors outside of the vmapped function, and the semantics for this are not (currently) clearly defined.
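For reference, the "pure" pattern that vmap does handle is returning the intermediate values from the function itself. A minimal illustrative sketch (not from the original report):

import torch
from functorch import vmap

def f(x):
  hidden = torch.tanh(x)  # an intermediate value we want to keep
  out = hidden.sum()
  return out, hidden      # returned values get properly un-batched by vmap

out, hidden = vmap(f)(torch.randn(4, 3))
print(out.shape)     # torch.Size([4])
print(hidden.shape)  # torch.Size([4, 3]) -- the full batch, no hooks needed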

where the BatchedTensor is clearly of shape [nsamples, ninput] but its shape is labelled as [ninput]

This is intended and is, in fact, how vmap works. Basically, vmap turns tensors into "batched tensors", which look like they have size [N] but are secretly of size [B, N] (and broadcast all operations across this additional batch dimension).
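A minimal sketch of what this looks like in practice (illustrative, not from the original report):

import torch
from functorch import vmap

captured = []
def f(x):
  captured.append(x)  # x is a BatchedTensor; storing it "leaks" it out of vmap
  print(x.shape)      # torch.Size([5]) -- the batch dim of size 3 stays hidden
  return x.sum()

out = vmap(f)(torch.randn(3, 5))
print(out.shape)          # torch.Size([3]) -- one result per batch element
print(captured[0].shape)  # torch.Size([5]) -- the leaked BatchedTensor still reports the per-sample shape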

Finally, to answer this question:

Is there a way to grab all the intermediate values?

Yes, but 1. it's a private API, and 2. we don't promise not to break your code in the future if you use this API :)

val = []
def f(x):
  val.append(x)  # stash the BatchedTensor that vmap passes in
  return x

vmap(f)(torch.randn(3, 5))

from functorch._C import _remove_batch_dim
print(val[0])  # a BatchedTensor reporting the per-sample shape [5]

vmap_level = 2  # the level the BatchedTensor was created at
batch_size = 3  # size of the vmapped dimension
out_dim = 0     # where the batch dim should live in the plain tensor
print(_remove_batch_dim(val[0], vmap_level, batch_size, out_dim))  # plain [3, 5] tensor

AlphaBetaGamma96 commented 2 years ago

Thank you for the prompt reply!

Yes, but 1. it's a private API, and 2. we don't promise not to break your code in the future if you use this API :)

Thank you for sharing a solution to this problem! I was wondering if I could follow up to understand what exactly vmap_level, batch_size, and out_dim correspond to? I assume batch_size is the size of the batch (obviously), and out_dim is the dim of the batch? And for vmap_level, I assume this is the level of the wrapped BatchedTensors? Is there a way to get this in case it changes on a per-case basis? I did notice the private API has a functorch._C.maybe_get_level; could this be a solution to that problem?

Thank you once again! :)

Chillee commented 2 years ago

I was wondering if I could follow up to understand what exactly vmap_level, batch_size, and out_dim correspond to?

vmap_level is the level of the wrapped BatchedTensors, and is basically an internal detail that allows vmap to compose with itself and grad.

batch_size is another implementation detail that we might actually be able to remove now...

out_dim is essentially just an argument that says where the extra "vmap batch dim" should live after vmap.

the private API has a functorch._C.maybe_get_level; could this be a solution to that problem?

I think that should work.
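For reference, a hedged sketch of what that could look like, assuming functorch._C.maybe_get_level takes a tensor and returns its wrapping level (all of this is private API and may change):

import torch
from functorch import vmap
from functorch._C import _remove_batch_dim, maybe_get_level

val = []
def f(x):
  val.append(x)
  return x

batch_size = 3
vmap(f)(torch.randn(batch_size, 5))

# query the level from the captured BatchedTensor instead of hard-coding it
vmap_level = maybe_get_level(val[0])
print(_remove_batch_dim(val[0], vmap_level, batch_size, 0).shape)  # torch.Size([3, 5])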