pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

How to get only the last few layers' gradients? #1101

Open pmzzs opened 1 year ago

pmzzs commented 1 year ago
from functorch import make_functional_with_buffers, vmap, grad
fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch) 
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)
gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())

This returns the gradients for the whole model. However, I only want the second-to-last layer's gradient, like:

gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())[-2]

Although this approach obtains the required gradient, it incurs a lot of unnecessary overhead. Is there any way to turn off requires_grad for all the earlier layers? Thanks for your answer!

zou3519 commented 1 year ago

functorch.grad computes gradients w.r.t. the first argument you pass it. Currently that is params (all parameters in the model), but the solution is to pass it only the parameters that you want gradients of.

Some pseudocode:

from functorch import make_functional_with_buffers, vmap, grad
fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    # pseudocode: we need to put the params back together into a single
    # params tuple that fmodel can understand
    params = (*first_layers_params, *last_layers_params)

    predictions = fmodel(params, buffers, batch) 
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)

# pseudocode: we need to split the params we want to compute gradients of from the params we don't
# want to compute gradients of.
first_layers_params, last_layers_params = partition(params)  

gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
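
One concrete way to fill in the partition step, as a minimal sketch: it assumes the tensors you want gradients for are the last k entries of the flat params tuple (make_functional_with_buffers returns parameters in registration order, so the final layers' tensors come last), and k counts tensors rather than layers.

def partition(params, k=2):
    # Hypothetical helper: split the flat params tuple into
    # (frozen, trainable) pieces. k counts tensors, not layers;
    # one Linear layer contributes two tensors (weight and bias),
    # so k=2 selects exactly one Linear layer.
    return params[:-k], params[-k:]

Since grad only tracks its first argument, autograd then never needs to build gradients for the earlier parameters.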
skxgogo commented 2 months ago

@zou3519 I have a similar question, but about jacrev. For example, I only want to compute the Jacobian with respect to the last layers' parameters. Can this work?
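
The same pattern should carry over: jacrev, like grad, differentiates with respect to its first argument by default (or the positions given via argnums), so passing only the last layers' parameters as the first argument restricts the Jacobian to them. A minimal sketch, reusing the hypothetical partition helper above and a small predict wrapper around fmodel:

from functorch import jacrev

def predict(last_layers_params, first_layers_params, buffers, sample):
    # Reassemble the full parameter tuple in registration order.
    params = (*first_layers_params, *last_layers_params)
    return fmodel(params, buffers, sample.unsqueeze(0))

first_layers_params, last_layers_params = partition(params)
# Jacobian of the model output w.r.t. only the last layers' parameters;
# `sample` stands for one input example.
jacobian = jacrev(predict)(last_layers_params, first_layers_params, buffers, sample)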