f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Extend part of the model #232

Open · haonan3 opened this issue 2 years ago

haonan3 commented 2 years ago

Hello,

I am wondering whether it is possible to extend only part of the model if I only need the batch gradients (individual gradients) of the last few layers.

I think `model = extend(model)` will waste memory if only the batch gradients of the last few layers are needed.

For example, if I only want to extend the last two layers of a large model (say `fc1` and `fc2`), can I do something like this?

```python
from backpack import extend

model.fc1 = extend(model.fc1)
model.fc2 = extend(model.fc2)
```
f-dangel commented 2 years ago

Hi,

For individual gradients, it is indeed sufficient to extend only the modules whose parameters you are interested in.

Best, Felix