pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
BSD 3-Clause "New" or "Revised" License

batching over model parameters #1094

Open LeanderK opened 1 year ago

LeanderK commented 1 year ago

I have a use case for functorch. I would like to evaluate many variations of a model's parameters very efficiently (I want to eliminate the Python loop). Here's example code for a simplified case that I got working:

import torch
from functorch import vmap

linear = torch.nn.Linear(10, 2)
default_weight = linear.weight.detach().clone()
sample_input = torch.rand(3, 10)
sample_add = torch.rand_like(default_weight)

def interpolate_weights(alpha):
    with torch.no_grad():
        # Replace the weight with a new Parameter holding the interpolated tensor
        res_weight = torch.nn.Parameter(default_weight + alpha * sample_add)
        linear.weight = res_weight
        return linear(sample_input)

Now I could loop with for alpha in torch.linspace(0.0, 1.0, 100), but I want to vectorize this loop because my code is prohibitively slow. Is functorch applicable here? Executing:

alphas = torch.linspace(0.0, 1.0, 100)
test_out = vmap(interpolate_weights)(alphas)

works, but doing something similar for a simple ResNet does not. I've tried using load_state_dict, but that fails:

import torch
from torchvision import models
from functorch import vmap

model_resnet = models.resnet18(pretrained=True)

named_params = list(model_resnet.named_parameters())
named_params_data = [(n, p.data) for (n, p) in named_params]

sample_data = torch.rand(10, 3, 224, 244)

def test_resnet(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            p_dict = {name: (old + alpha * new_params[i]) for i, (name, old) in enumerate(named_params_data)}
            # load_state_dict copies each tensor into the existing parameter in place
            model_resnet.load_state_dict(p_dict, strict=False)
            out = model_resnet(sample_data)
            return out
    return interpolate

rand_tensor = [torch.rand_like(p) for n, p in named_params_data]

to_vmap_resnet = test_resnet(rand_tensor)
test_out = vmap(to_vmap_resnet)(alphas)

results in:

While copying the parameter named "fc.bias", whose dimensions in the model are torch.Size([1000]) and whose dimensions in the checkpoint are torch.Size([1000]), an exception occurred : ('vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.',).
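For the toy nn.Linear case, the module does not have to be mutated at all: the interpolated weight can be built out of place inside the vmapped function and passed straight to torch.nn.functional.linear. A minimal sketch, assuming PyTorch 2.x, where torch.func.vmap is the successor to functorch.vmap; the function name interpolated_linear is hypothetical:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

torch.manual_seed(0)
linear = torch.nn.Linear(10, 2)
default_weight = linear.weight.detach().clone()
bias = linear.bias.detach().clone()
sample_input = torch.rand(3, 10)
sample_add = torch.rand_like(default_weight)

def interpolated_linear(alpha):
    # alpha is a 0-dim tensor; the new weight is a fresh tensor built out of place,
    # so no module state is reassigned or copied into
    weight = default_weight + alpha * sample_add
    return F.linear(sample_input, weight, bias)

alphas = torch.linspace(0.0, 1.0, 100)
batched_out = vmap(interpolated_linear)(alphas)  # shape (100, 3, 2)

# Reference: the Python loop that the vmap call replaces
loop_out = torch.stack([interpolated_linear(a) for a in alphas])
```

Because nothing is written in place, the same function runs unchanged under vmap and in a plain loop, which is what the last line uses as a reference.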

LeanderK commented 1 year ago

Is this a legitimate way to solve this? It doesn't give me an error, but I am very unsure why it works now.

def test_resnet_2(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # Rebind each parameter attribute to a new Parameter holding the interpolated tensor
            for i, (name, old_p) in enumerate(named_params_data):
                new_p = new_params[i]
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p + alpha * new_p))

            out = model_resnet(sample_data)

            # Restore the original parameters
            for i, (name, old_p) in enumerate(named_params_data):
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p))
            return out
    return interpolate

to_vmap_resnet = test_resnet_2(rand_tensor)
test_out2 = vmap(to_vmap_resnet)(alphas)
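Why the setattr version works while load_state_dict fails: load_state_dict copies each incoming tensor into the existing, unbatched parameter in place (hence "While copying the parameter" in the error). Inside vmap, alpha and therefore every interpolated tensor are batched, and vmap rejects an in-place write of a batched tensor into an unbatched one. setattr instead rebinds the attribute to a new Parameter built out of place, so nothing unbatched is ever written to. The rule can be reproduced with plain tensors; this is a minimal sketch assuming PyTorch 2.x's torch.func.vmap, and the function names inplace_update and out_of_place_update are hypothetical:

```python
import torch
from torch.func import vmap

def inplace_update(delta):
    w = torch.zeros(3)  # a regular (unbatched) tensor, like an existing parameter
    w.add_(delta)       # in-place write of a batched tensor: vmap raises an error
    return w

def out_of_place_update(delta):
    w = torch.zeros(3)
    return w + delta    # produces a new batched tensor, which vmap allows

deltas = torch.randn(5, 3)

try:
    vmap(inplace_update)(deltas)
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place update rejected:", str(err).splitlines()[0])

result = vmap(out_of_place_update)(deltas)  # shape (5, 3), equal to deltas
```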

EDIT: I found an even simpler solution. This is the correct approach, right?

from functorch import make_functional_with_buffers, vmap

def test_resnet_4(new_params, sample_data, model_resnet):
    # Split the module into a stateless function plus its parameters and buffers
    func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
    def interpolate(alpha):
        with torch.no_grad():
            interpol_params = [torch.nn.Parameter(old_p + alpha * new_params[i]) for i, old_p in enumerate(params)]

            out = func_model(interpol_params, buff, sample_data)
            return out
    return interpolate

to_vmap_resnet = test_resnet_4(rand_tensor, sample_data, model_resnet)
test_out2 = vmap(to_vmap_resnet)(alphas)
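On PyTorch 2.x, make_functional_with_buffers is deprecated in favor of torch.func.functional_call, which runs the original module with a dict of replacement tensors. A minimal sketch of the same per-alpha interpolation under that assumption, on a small stand-in model instead of ResNet-18 (model, x, base, directions and interpolate are hypothetical names). Parameters and buffers that are not in the dict are taken from the module itself:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)
# Small stand-in for resnet18 so the example runs quickly
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.rand(10, 4)

# Detached copies of the current parameters and one random direction per parameter
base = {n: p.detach().clone() for n, p in model.named_parameters()}
directions = {n: torch.rand_like(p) for n, p in base.items()}

def interpolate(alpha):
    # New tensors are built out of place; the module itself is never modified
    params = {n: base[n] + alpha * directions[n] for n in base}
    return functional_call(model, params, (x,))

alphas = torch.linspace(0.0, 1.0, 100)
with torch.no_grad():
    out = vmap(interpolate)(alphas)  # shape (100, 10, 3)
    loop_out = torch.stack([interpolate(a) for a in alphas])
```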
samdow commented 1 year ago

Hi @LeanderK! Thanks for the interesting issue! Since it sounds like this works, that's a totally fine way of doing it!

One thing that might come up: if you do N runs of this model (instead of 1), it will be faster to do something similar to the ensembling API. In your version the new parameters are built N times, whereas this way you build them only once and then combine them. This is also useful if you want to train the model (batch norm should work with the ensemble).

For this use case, since it looks like you want very specific initializations, it might be better to riff on the idea behind the ensemble API:

from functorch import make_functional_with_buffers, vmap

def test_resnet_4(func_model, buff, sample_data):
    def interpolate(interpol_params):
        with torch.no_grad():
            out = func_model(interpol_params, buff, sample_data)
            return out
    return interpolate


func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
# One list of interpolated parameters per alpha
interpol_params = [[torch.nn.Parameter(old_p + alpha * rand_tensor[i]) for i, old_p in enumerate(params)] for alpha in alphas]
# Stack the i-th parameter across all alphas (this is basically what the ensemble API is doing)
interpol_params = [torch.stack(i) for i in zip(*interpol_params)]
to_vmap_resnet = test_resnet_4(func_model, buff, sample_data)
test_out2 = vmap(to_vmap_resnet)(interpol_params)
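The same "build once, then stack" idea expressed with torch.func (PyTorch 2.x assumed): every interpolated parameter set is materialized up front and stacked along a new leading dimension, and vmap then maps over the dict of stacked tensors. torch.func.stack_module_state does this stacking for a list of actual modules; here a hand-rolled torch.stack plays that role, mirroring the zip/stack line above, because the parameter sets come from interpolation rather than from separate modules. The names model, base, directions, stacked and run_one are hypothetical:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x = torch.rand(10, 4)
base = {n: p.detach().clone() for n, p in model.named_parameters()}
directions = {n: torch.rand_like(p) for n, p in base.items()}
alphas = torch.linspace(0.0, 1.0, 100)

# Build every interpolated parameter set once, stacked along dim 0:
# stacked[name] has shape (len(alphas), *param.shape)
stacked = {
    n: torch.stack([base[n] + a * directions[n] for a in alphas])
    for n in base
}

def run_one(params):
    # params holds one slice of each stacked tensor, i.e. one interpolation point
    return functional_call(model, params, (x,))

with torch.no_grad():
    out = vmap(run_one)(stacked)  # shape (100, 10, 3)
```

Once stacked is built, it can be reused across many calls (for example, different inputs x) without rebuilding the parameters, which is the saving described above.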

Then, if you want to train, you can also expand the buffers and vmap across them along with interpol_params so that batch norm works.

Hope that helps! We are also looking at changing the module API soon, to help align parts of the functorch API with the PyTorch API. If you're using the nightly build, I can point you to the new API if you're curious.