pmzzs opened this issue 1 year ago.
`functorch.grad` computes gradients with respect to the first argument you pass it. Right now that argument is `params` (all parameters in the model). The solution is to pass it only the parameters you want gradients of.

Some pseudocode:
```python
import torch
from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # Put the params back together, in their original order, into a single
    # tuple that fmodel can understand
    params = (*first_layers_params, *last_layers_params)
    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

# grad differentiates only w.r.t. the first argument (argnums=0 by default)
ft_compute_grad = grad(compute_loss_stateless_model)

# Split the params we want gradients of from the params we don't.
# Assumes the final layer has exactly two parameters (e.g. an nn.Linear
# weight and bias); adjust the split index for your model.
first_layers_params, last_layers_params = params[:-2], params[-2:]
gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers,
                           train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
```
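Here is a minimal, self-contained sketch of the same idea. It assumes PyTorch 2.x, where `functorch.grad` is available as `torch.func.grad` and `torch.func.functional_call` plays the role of `make_functional_with_buffers`. The toy `nn.Sequential` model, the `"2."` name prefix used to pick out the last layer, and the MSE loss are illustrative assumptions, not part of the original answer.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)

# Hypothetical toy model: index 2 is the last Linear layer ("2.weight", "2.bias")
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
criterion = nn.MSELoss()

# Split the parameters by name: only the last layer will be differentiated
params = {k: v.detach() for k, v in net.named_parameters()}
last_layer_params = {k: v for k, v in params.items() if k.startswith("2.")}
first_layers_params = {k: v for k, v in params.items() if not k.startswith("2.")}

def compute_loss(last_layer_params, first_layers_params, sample, target):
    # Merge both groups back into one dict that functional_call understands
    all_params = {**first_layers_params, **last_layer_params}
    predictions = functional_call(net, all_params, (sample.unsqueeze(0),))
    return criterion(predictions, target.unsqueeze(0))

# grad differentiates w.r.t. the first positional argument only (argnums=0)
ft_compute_grad = grad(compute_loss)

sample, target = torch.randn(4), torch.randn(3)
grads = ft_compute_grad(last_layer_params, first_layers_params, sample, target)
print({k: tuple(g.shape) for k, g in grads.items()})
# {'2.weight': (3, 8), '2.bias': (3,)}
```

The returned `grads` has the same structure as the first argument, so no gradients are materialized for the first layer.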
@zou3519 I have a similar question, but about `jacrev`. For example, I only want to compute the Jacobian with respect to the last layers. Can this work?
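This sketch shows how the same pattern might carry over to `jacrev`. It assumes PyTorch 2.x (`torch.func.jacrev`, the successor of `functorch.jacrev`). Like `grad`, `jacrev` differentiates only with respect to `argnums` (default `0`), so passing the last layer's parameters first should restrict the Jacobian to them. The toy model and the `"2."` name prefix are hypothetical.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# Hypothetical toy model; "2.weight" and "2.bias" are the last layer's parameters
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
params = {k: v.detach() for k, v in net.named_parameters()}
last = {k: v for k, v in params.items() if k.startswith("2.")}
frozen = {k: v for k, v in params.items() if not k.startswith("2.")}

def model_output(last, frozen, x):
    # Recombine the two groups and run the model functionally on one sample
    return functional_call(net, {**frozen, **last}, (x,))

x = torch.randn(4)
# Jacobian of the 3 outputs w.r.t. the last layer's parameters only (argnums=0)
jac = jacrev(model_output)(last, frozen, x)
print({k: tuple(j.shape) for k, j in jac.items()})
# {'2.weight': (3, 3, 8), '2.bias': (3, 3)}
```

Each Jacobian entry has shape `output_shape + parameter_shape`, and only the last layer's parameters appear in the result.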
This returns the gradient of the whole model. However, I only want the gradient of the second-to-last layer, like:
Although this method does produce the required gradient, it adds a lot of unnecessary overhead. Is there any way to turn off `requires_grad` for all the previous layers? Thanks for your answer!
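For comparison, here is a sketch of the plain-autograd route the question asks about. Calling `requires_grad_(False)` on the earlier layers' parameters means autograd does not track them, and `torch.autograd.grad` is then asked only for the last layer's parameters. This is standard `torch.autograd`, not a functorch feature, and the toy model is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: only the final Linear layer should receive gradients
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
criterion = nn.MSELoss()

# Freeze every layer except the last one
for layer in net[:-1]:
    for p in layer.parameters():
        p.requires_grad_(False)

last_params = list(net[-1].parameters())
sample, target = torch.randn(1, 4), torch.randn(1, 3)

loss = criterion(net(sample), target)
# Gradients are requested only for the last layer's weight and bias
grad_w, grad_b = torch.autograd.grad(loss, last_params)
print(grad_w.shape, grad_b.shape)
# torch.Size([3, 8]) torch.Size([3])
```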