pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Calculating Jacobian of a model with respect to its parameters? #334

Open mohamad-amin opened 2 years ago

mohamad-amin commented 2 years ago

Hey, I would like to calculate the Jacobian of a model with respect to its parameters, as in the title. Right now I'm trying this:

from functorch import make_functional_with_buffers, jacrev

func, params, buffers = make_functional_with_buffers(model)
J = jacrev(lambda p: func(p, buffers, input_dict))(params)

But this gives me the following error:

----> 1 jacrev(lambda p: func(p, buffers, input_dict))(params)

~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py in wrapper_fn(*args)
    356     def wrapper_fn(*args):
    357         f_wrapper, primals = _argnums_partial(f, args, argnums)
--> 358         output, vjp_fn = vjp(f_wrapper, *primals)
    359         assert isinstance(output, torch.Tensor)
    360         # TODO: does jacrev compose with vmap...? the eye call should make it so that it doesn't

~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py in vjp(f, *primals)
    227             primals = _wrap_all_tensors(primals, level)
    228             diff_primals = _create_differentiable(primals, level)
--> 229             primals_out = f(*diff_primals)
    230
    231             results = _undo_create_differentiable(primals_out, level)

~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py in f_wrapper(*wrapper_args)
    405 def _argnums_partial(f, args, argnums):
    406     def f_wrapper(*wrapper_args):
--> 407         replaced_args = _replace_args(args, wrapper_args, argnums)
    408         return f(*replaced_args)
    409     wrapper_args = _slice_argnums(args, argnums)

~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py in _replace_args(old_args, new_args, argnums)
    378             return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
    379         else:
--> 380             raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
    381     if isinstance(argnums, tuple):
    382         if len(new_args) == len(argnums):

RuntimeError: new_args should be of size 1, was of size 40

This is perhaps expected, since params is a tuple of tensors rather than a single tensor. JAX handles this by returning a pytree of Jacobians (e.g. a dict), though I haven't explored all the cases.

But anyway, is it possible to do this in functorch? (A possible solution might be to loop over the parameters and compute a parameter-wise Jacobian, roughly as sketched below, but I guess that would be slow, right?)
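
A minimal, untested sketch of what I mean by the parameter-wise loop (it reuses func, params, buffers, and input_dict from above, and assumes the model's output dict has a 'logits' entry):

per_param_jacobians = []
for i in range(len(params)):
    # Differentiate w.r.t. the i-th parameter only, holding the others fixed.
    def f_i(p_i, i=i):
        new_params = params[:i] + (p_i,) + params[i + 1:]
        return func(new_params, buffers, input_dict)['logits']
    per_param_jacobians.append(jacrev(f_i)(params[i]))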

However, the following works for me (slow, but it does the job), and I'm using it right now to calculate the Jacobians of a model with respect to its parameters:

def calculate_jacobian_wrt_params_torch(model, inputs):
    # to_device and flatten are helpers defined elsewhere: to_device moves the dict
    # to the given device, and flatten concatenates a tuple of gradients into a
    # single flat vector.
    ggs = []
    model.zero_grad()
    input_dict = to_device({'inputs': inputs.requires_grad_()}, torch.device('cuda'))
    out = model(input_dict)['logits']
    for i in range(out.shape[0]):  # for each datapoint
        gs = []
        for j in range(out.shape[1]):  # for each output neuron (logit)
            ps = torch.autograd.grad(out[i, j], model.parameters(), retain_graph=True)
            gs.append(flatten(ps))
        ggs.append(torch.stack(gs))
        del gs
    torch.cuda.empty_cache()
    J = torch.stack(ggs).detach().cpu().numpy()
    return J

For my use case, I would be super happy with just applying vmap to the inner for loop, the one that iterates over out.shape[1] (the logits).
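
Roughly what I have in mind (an untested sketch, reusing func, params, buffers, and input_dict from above): build an identity basis of cotangents and vmap the vjp over it, which is essentially what jacrev does internally:

import torch
from functorch import vmap, vjp

out, vjp_fn = vjp(lambda p: func(p, buffers, input_dict)['logits'], params)

# One one-hot cotangent per (datapoint, logit) entry of the output.
basis = torch.eye(out.numel(), dtype=out.dtype, device=out.device).view(-1, *out.shape)

# vjp_fn returns a tuple with one entry per primal; the single primal here is the params tuple.
(batched_grads,) = vmap(vjp_fn)(basis)

# Reshape each per-parameter gradient to (num_datapoints, num_logits, *param.shape).
J = [g.view(*out.shape, *p.shape) for g, p in zip(batched_grads, params)]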

P.S.: I've also tried torch.autograd.functional.jacobian, but I get the same error. Maybe I'm missing something here?

Use case: computing the neural tangent kernel (https://en.wikipedia.org/wiki/Neural_tangent_kernel).

mohamad-amin commented 2 years ago

Update:

I tried layer-wise jacobians:

func, params, buffers = make_functional_with_buffers(model)
params = list(params)

def jac_p(p, i):
    params[i] = p
    out = func(params, buffers, input_dict)['logits']
    return out

J0 = jacrev(jac_p)(params[0], 0)

And I got an out-of-memory error, whereas the normal Jacobian computation using PyTorch's autograd.grad runs smoothly.

Model size: ~400k parameters
params[0] shape: torch.Size([16, 3, 3, 3])
Input size: 100 datapoints of shape 32x32x3
GPU: V100 16GB

RuntimeError: CUDA out of memory. Tried to allocate 30.52 GiB

mohamad-amin commented 2 years ago

Another update:

func, params, buffers = make_functional_with_buffers(model)
params = list(params)

def jac_p_o(p, i, o):
    params[i] = p
    out = func(params, buffers, input_dict)['logits']
    return out[o]

J0 = jacrev(jac_p_o)(params[0], 0, 0)

It runs, but prints this warning a couple of times:

[W BatchedFallback.cpp:108] Warning: There is a performance drop because we have not yet implemented the batching rule for aten::native_batch_norm_backward. Please file us an issue on GitHub so that we can prioritize its implementation. (function warnFallback)

Moreover, this doesn't output what I wanted: its output has shape (10, 16, 3, 3, 3). I have 100 datapoints and I'm calculating the Jacobian of the model with respect to the parameters over these inputs (I guess?), which should result in a matrix of shape (100, 16, 3, 3, 3) for one logit, if I'm not wrong, right?

zou3519 commented 2 years ago

Hi @mohamad-amin, thanks for the issue; we are happy to help. Can you clarify exactly which quantity you're looking to compute? From reading your messages, I think it is the following:

import torch
import torch.nn as nn
from functorch import vmap, jacrev, make_functional_with_buffers
# simple model and data for demonstration
model = nn.Linear(3, 3)
data = torch.randn(100, 3) # 100 datapoints
func, params, buffers = make_functional_with_buffers(model)

# Compute a jacobian of the parameter for each datapoint
result = vmap(jacrev(func), (None, None, 0))(params, buffers, data)

# The jacobian w.r.t. model.weight has size (100, 3, 3, 3)
# - model.weight is (3, 3) and the output shape is (3,), so each jacobian has size (3, 3, 3)
# - 100 for the 100 datapoints.
result[0].shape

# The jacobian w.r.t. model.bias has size (100, 3, 3)
# - model.bias is (3,) and the output shape is (3,), so each jacobian has size (3, 3)
# - 100 for the 100 datapoints.
result[1].shape

mohamad-amin commented 2 years ago

Hey,

I'm trying to compute this (the expression in the attached image), which requires Jacobians of f(X) with respect to the parameters, \nabla_\theta f(X), where X is our datapoints. So in a batched manner, it would be:

[Screenshot: the batched form of the expression]
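
For reference, my transcription of the expressions the images show (the empirical NTK, per pair of points and in batched form):

\hat{\Theta}(x, x') = \nabla_\theta f_\theta(x)\, \nabla_\theta f_\theta(x')^{\top},
\qquad
\hat{\Theta}(X, X') = \nabla_\theta f_\theta(X)\, \nabla_\theta f_\theta(X')^{\top}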

And right now I'm trying to compute the jacobians.

Thanks for the script you provided. I'll test it later today and check whether it computes what I have in mind (the picture above). But apart from that, why is the code I've written not working? Isn't it supposed to work, since it's simply the Jacobian of a function?

Edit: I'm getting the same error with a bigger model (it's related to params, as its length is 40):

RuntimeError: new_args should be of size 1, was of size 40

Actually, even your example gives me the same error: (params size is 2 here)

RuntimeError: new_args should be of size 1, was of size 2

zou3519 commented 2 years ago

Actually, even your example gives me the same error: (params size is 2 here)

RuntimeError: new_args should be of size 1, was of size 2

What version of functorch and PyTorch are you running? I tested the example I sent on the latest PyTorch and functorch. (The latest pytorch nightly is 1.11.0.dev20211212 and the latest functorch main)

mohamad-amin commented 2 years ago

This is my pip info:


Name: torch Version: 1.10.0

Name: functorch Version: 0.0.1a0+6107f49

Should I be using the PyTorch nightly build?

Actually, even your example gives me the same error: (params size is 2 here)

RuntimeError: new_args should be of size 1, was of size 2

What version of functorch and PyTorch are you running? I tested the example I sent on the latest PyTorch and functorch. (The latest pytorch nightly is 1.11.0.dev20211212 and the latest functorch main)

Update: I just tried the Colab installation instructions provided for functorch, but it's not working: https://colab.research.google.com/drive/1UrESS8kwpOrlS-QOZLq9y6XwEIvj810H?usp=sharing

Chillee commented 2 years ago

@mohamad-amin Your colab link needs to be shared.

The library is still currently in rapid development, so I wouldn't be shocked if there were fixes between 1.10 and nightly.

mohamad-amin commented 2 years ago

Oh! Sorry about that. Updated the permission! @Chillee

zou3519 commented 2 years ago

There were a lot of fixes to functorch related to make_functional and Jacobian computation between 1.10 and the nightly. So yes, you'll have to try the nightly for this to work. Instructions for installing the nightly can be found here: https://github.com/pytorch/functorch#installing-functorch-main .

zou3519 commented 2 years ago

@mohamad-amin I understand a bit more about neural tangent kernels after reading through some papers. I can probably put together an example in the next week, but I'm curious about the use case -- given a neural network, we can compute its neural tangent kernel, which is some numeric quantity. What do researchers do with these quantities? It looks like some researchers derive a mathematical expression for the NTK and then use that NTK in an SVM, but that particular case seems different from what we're doing here.

mohamad-amin commented 2 years ago

@zou3519 That's nice, thanks for the effort! So what we're computing here is the "empirical" NTK, which I think was first derived (or shown to be useful) here: https://arxiv.org/pdf/1902.06720.pdf. Basically, the empirical NTK of a neural network can model the network's evolution while it is trained using SGD:

[Screenshot: equation describing the network's evolution under SGD in terms of the empirical NTK]

Here f_t(x) is the actual neural network that we have, and f_t^lin(x) is its approximation via kernel ridge(-less) regression, with the kernel being the empirical NTK computed at the initialization of f_t(x) (initialization referring to the network's parameters at initialization, the ones we use to compute the Jacobians and the NTK):

[Screenshot: the definition of f_t^lin(x), the first-order Taylor expansion around the initial parameters]

So this is just the first-order Taylor expansion of the neural network (w.r.t. its parameters, around initialization). If your network is wide enough, this Taylor expansion will be a pretty good approximation of the actual neural network after being trained with SGD for time t. So you basically get the neural network's predictions after training without actually training it.
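
Written out (my transcription of the formula from the Lee et al. paper linked above), the linearization is:

f_t^{lin}(x) = f_0(x) + \nabla_\theta f_0(x) \, (\theta_t - \theta_0)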

This can have a lot of use cases, some of which are mentioned here: https://github.com/google/neural-tangents#papers

I think reading the paper I mentioned first will help a lot for getting going with empirical NTKs, but if you wanna know what the NTK actually is, I think this is the original NTK paper: https://arxiv.org/abs/1806.07572 (hard to read IMO, but brilliant once you get familiar with the notation).

Let me know if there's anything else that I can help with!

zou3519 commented 2 years ago

Hi @mohamad-amin!

Thanks for the references to the papers and for the detailed explanation. I'm still working my way through the papers, but I think I have a working example of computing the empirical NTK.

I'm a bit confused about what the actual shape of the empirical NTK is. Let's say I had a simple nn.Linear(5, 7, bias=False) layer, and x1 has shape (3, 5) and x2 has shape (4, 5) (where x1, x2 are batches of examples, each of shape (5,)).

Then the Jacobian for each example has shape (7, 7, 5), i.e. (7, 35) after flattening the weight. If we're computing \nabla_\theta f(x_1) \, \nabla_\theta f(x_2)^\top (the expression in the image), then the NTK between two examples should have shape (7, 7). This means the final NTK should have shape (3, 4, 7, 7).

Is that correct, or is the reduction done over the entire jacobian, resulting in an empirical NTK of shape (3, 4)?

mohamad-amin commented 2 years ago

Hey,

Great, looking forward to it! I don't think either of these two shapes can be called the "correct" shape. In fact, both are used in the literature (in both empirical and theoretical papers; e.g. https://arxiv.org/abs/1904.11955 uses the n x n shape and https://arxiv.org/abs/1806.07572 uses the nk x nk shape, where k is the number of neurons in the last layer).

I'm not too confident about this, but as far as I know, it depends on how (or whether) you contract the Jacobian of f w.r.t. its parameters at point x. You may or may not want to take the trace over the computed (7, 7) block for each pair of points. Intuitively, you lose some precision when tracing over this matrix for the empirical neural tangent kernel, but as mentioned here (in the description for trace_axes), it shouldn't be very significant as long as, in the infinite-width limit, your network converges to have one-hot outputs (proven to be the case for most modern NN architectures in https://arxiv.org/pdf/2105.03703.pdf and https://arxiv.org/pdf/2006.14548.pdf).

This is my personal understanding, which may or may not be helpful: I view the NTK as a specific covariance matrix between the two inputs that somehow captures the infinite network's embeddings. Later, this covariance matrix can be used as the kernel in kernel ridge regression, which, as proven, will give the same predictions as the infinite neural network. From this point of view, when you reduce over the axes corresponding to the last layer's neurons (the 7 neurons in your example), you're summing over the trace of the covariance matrix; thus you can no longer answer the question "what is the covariance between output f(x)_i and f(x')_j", but you just have a scalar describing the covariance between f(x) and f(x'). Thus, in some cases, using this kernel makes your approximations less accurate (*), but for FCNs, for example, it can be proved that in the infinite-width limit your approximation will be the same whether or not you take the trace, as the last layer's outputs converge to being one-hot in that limit.

*: I've read somewhere that as long as your network's outputs converge to one-hot in the infinite-width limit, even for finite-width approximations with the empirical NTK, tracing over the covariance gives better accuracy, but I'm not convinced this is true.
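
As a concrete illustration of the two conventions (my own toy numbers, matching the shapes in your question, with the Jacobians flattened over the parameter dimension):

import torch

J1 = torch.randn(3, 7, 35)  # per-example Jacobians for x1: (num_points, num_outputs, num_params)
J2 = torch.randn(4, 7, 35)  # per-example Jacobians for x2

ntk_full = torch.einsum('Naf,Mbf->NMab', J1, J2)  # shape (3, 4, 7, 7): one (7, 7) block per pair
ntk_trace = torch.einsum('Naf,Maf->NM', J1, J2)   # shape (3, 4): trace of each (7, 7) block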

zou3519 commented 2 years ago

Hey @mohamad-amin, I just wanted to give a quick status update (sorry for leaving you hanging!). We've been stuck on a bug https://github.com/pytorch/functorch/issues/417 that makes the NTK computation error out in functorch, but we're working through it :). Hopefully will fix the bug sometime next week.

zou3519 commented 2 years ago

@mohamad-amin we finally fixed https://github.com/pytorch/functorch/issues/417! Here's the first version of the example of how to compute NTKs with functorch. https://github.com/pytorch/functorch/blob/cb876fad2b2a9269424c8212a82652f372db6dfa/notebooks/neural_tangent_kernels.ipynb .

The code runs on the latest build of functorch (see "functorch main" in https://github.com/pytorch/functorch#installing-functorch-main for installation instructions), if you want to give it a try.

We would greatly appreciate your feedback on the example! Ultimately I'd like to add it to our website as one of the tutorials (https://pytorch.org/functorch/nightly/).

ain-soph commented 2 years ago

@zou3519 Thanks for your tutorial on NTK computation.

Following your procedure, I also implemented a PyTorch version without functorch by calling torch.nn.utils._stateless.functional_call. I'm curious what the advantage of using functorch is. (Mentioned in #788.)

Besides, to make the tutorial more complete, it would be extremely nice if you could add the part about computing

[Image: the linearized model f_t^lin(x)]

i.e., a linear approximation of the model at training epoch t. (The formula is illustrated above, in those papers.)


import torch
import torch.nn as nn
from torch.nn.utils import _stateless

import functools

def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,
        parameters: dict[str, nn.Parameter] = None,
        compute='full') -> torch.Tensor:
    # N/M index the datapoints of input1/input2, a/b index output neurons,
    # and f indexes the flattened parameters.
    einsum_expr: str = ''
    match compute:
        case 'full':
            einsum_expr = 'Naf,Mbf->NMab'  # one (C, C) block per pair of points
        case 'trace':
            einsum_expr = 'Naf,Maf->NM'    # trace over the output dimension
        case 'diagonal':
            einsum_expr = 'Naf,Maf->NMa'   # diagonal of each (C, C) block
        case _:
            raise ValueError(compute)

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def func(*params: torch.Tensor, _input: torch.Tensor = None):
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output  # (N, C)

    jac1: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input1), values, vectorize=True)
    jac2: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input2), values, vectorize=True)
    jac1 = [j.flatten(2) for j in jac1]
    jac2 = [j.flatten(2) for j in jac2]
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]).sum(0)
    return result
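
For example (my own toy call, assuming a model whose forward takes a plain tensor), this reproduces the shapes discussed earlier in the thread:

model = nn.Linear(5, 7, bias=False)
x1, x2 = torch.randn(3, 5), torch.randn(4, 5)

k_full = ntk(model, x1, x2, compute='full')    # shape (3, 4, 7, 7)
k_trace = ntk(model, x1, x2, compute='trace')  # shape (3, 4)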

Chillee commented 2 years ago

@ain-soph Well, you're only computing the explicit empirical NTK in this example :P You may find computing the implicit NTK harder with existing PyTorch APIs.

That being said, in many cases, you can use both functorch and PyTorch core to compute the same things. It's up to you which API you prefer more. Personally, I like functorch (i.e. Jax)-style autograd APIs when thinking about complicated gradient quantities, while I prefer PyTorch's imperative AD API for more traditional neural networks. The goal of functorch is to give you that choice (and also give you vmap :P)!

I will note that the vectorize=True flag in PyTorch is (essentially) using an earlier version of functorch with less operator coverage.

ain-soph commented 2 years ago

@Chillee I'm gradually coming to see that you are correct.
Without vmap, I have to compute multiple JVPs and VJPs in a for loop for the implicit NTK algorithm; a rough sketch of what I mean is below.
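
A sketch of that NTK-vector-product idea with functorch (my own untested code, along the lines of the functorch tutorial; f(params, x) is assumed to map the parameters and a single example to a (C,)-shaped output):

import torch
from functorch import vmap, vjp, jvp

def ntk_block(f, params, x1, x2):
    # The (C, C) block of the empirical NTK between two single examples,
    # computed without materializing the full Jacobians.
    f1 = lambda p: f(p, x1)
    f2 = lambda p: f(p, x2)
    out2, vjp_fn = vjp(f2, params)

    def one_slice(cotangent):
        (tangent,) = vjp_fn(cotangent)           # J(x2)^T @ cotangent, in parameter space
        _, col = jvp(f1, (params,), (tangent,))  # J(x1) @ tangent, shape (C,)
        return col

    basis = torch.eye(out2.numel(), dtype=out2.dtype, device=out2.device)
    # Rows are J(x1) J(x2)^T applied to each basis vector, so transpose at the end.
    return vmap(one_slice)(basis).T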

Fangwq commented 1 year ago

@ain-soph, thanks for sharing your code. However, it gives quite different results from the torch NTK tutorial. Is that normal?

ain-soph commented 1 year ago

@ain-soph, thanks for sharing your code. However, it gives quite different results from the torch NTK tutorial. Is that normal?

@Fangwq I just compared the code with the NTK tutorial:

result1 = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
result2 = ntk(net, x_train, x_test)

print((result1-result2).abs().max())
print((result1-result2).abs().sum())

Output:

tensor(7.6294e-06, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.0035, device='cuda:0', grad_fn=<SumBackward0>)

I don't think the results are very different.