pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Different gradients for HyperNet training #1071

Open bkoyuncu opened 1 year ago

bkoyuncu commented 1 year ago

TL;DR: Is there a way to optimize a model created by combine_state_for_ensemble using loss.backward()?

Hi, I am using combine_state_for_ensemble for HyperNet training.

from functorch import combine_state_for_ensemble, vmap

fmodel, fparams, fbuffers = combine_state_for_ensemble([HyperMLP() for i in range(K)])
for p in fparams:
    p.requires_grad_()
weights_and_biases = vmap(fmodel)(fparams, fbuffers, z.expand(K, -1, -1))  # parallelizes over K
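On PyTorch 2.x, functorch's combine_state_for_ensemble is deprecated in favor of torch.func.stack_module_state, so the hypernetwork stage above can be sketched roughly as follows. HyperMLP's architecture, the sizes, and the shape of z are hypothetical placeholders. The point is that the stacked parameters are leaf tensors with requires_grad=True, while weights_and_biases is a non-leaf output computed from them.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

K, Z_DIM, OUT_DIM = 3, 4, 10  # hypothetical sizes


class HyperMLP(nn.Module):
    """Hypothetical stand-in for the HyperMLP in the issue."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(Z_DIM, 16), nn.ReLU(), nn.Linear(16, OUT_DIM))

    def forward(self, z):
        return self.net(z)


hypernets = [HyperMLP() for _ in range(K)]
# Stack the K parameter sets along a new leading dimension.
# The stacked tensors are fresh leaves that require grad.
hyper_params, hyper_buffers = stack_module_state(hypernets)
base_hypernet = copy.deepcopy(hypernets[0])  # only used as a template for functional_call


def hyper_fmodel(params, buffers, z):
    # Run the template module with one ensemble member's parameters swapped in
    return functional_call(base_hypernet, (params, buffers), (z,))


z = torch.randn(1, Z_DIM)
# vmap maps over dim 0 of the params, the buffers and the expanded z
weights_and_biases = vmap(hyper_fmodel)(hyper_params, hyper_buffers, z.expand(K, -1, -1))
print(weights_and_biases.shape)  # torch.Size([3, 1, 10])
```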

After creating weights_and_biases, I reshape them into the right shapes (ws_and_bs) and use them as the parameters of another ensemble:

fmodel, fparams, fbuffers = combine_state_for_ensemble([SimpleMLP() for i in range(K)])
outputs = vmap(fmodel)(ws_and_bs, fbuffers, inputs)  # the generated ws_and_bs replace fparams

This produces exactly the same outputs as an equivalent version that uses loops instead of vmap. However, the gradients are somehow different:

loss = compute_loss(outputs)
loss.backward()

Do you have any idea why?
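One way to narrow this down is to compare the gradients that reach the hypernetwork's parameters through the vmap path and through a plain Python loop. The sketch below uses torch.func (PyTorch 2.x) rather than functorch; SimpleMLP, the single-layer generator standing in for the hypernetwork, the unflatten helper, and the squared-output loss are all hypothetical placeholders for the issue's HyperMLP, reshaping step, and compute_loss.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)
K, B, IN, OUT, Z_DIM = 3, 5, 2, 4, 8  # hypothetical sizes


class SimpleMLP(nn.Module):
    """Hypothetical stand-in for the target network in the issue."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(IN, OUT)

    def forward(self, x):
        return self.fc(x)


base = SimpleMLP()
n_total = sum(p.numel() for p in base.parameters())
# Hypothetical stand-in for the hypernetwork: one linear layer mapping codes to flat weights
generator = nn.Linear(Z_DIM, n_total)
codes = torch.randn(K, Z_DIM)
inputs = torch.randn(K, B, IN)


def unflatten(flat):
    """Split a (K, n_total) tensor into {name: (K, *param.shape)} for `base`."""
    out, offset = {}, 0
    for name, p in base.named_parameters():
        out[name] = flat[:, offset:offset + p.numel()].reshape(flat.shape[0], *p.shape)
        offset += p.numel()
    return out


def hypernet_grads(use_vmap):
    generator.zero_grad()
    ws_and_bs = unflatten(generator(codes))  # non-leaf tensors, still connected to generator
    if use_vmap:
        outputs = vmap(lambda p, x: functional_call(base, p, (x,)))(ws_and_bs, inputs)
    else:
        outputs = torch.stack([
            functional_call(base, {n: t[k] for n, t in ws_and_bs.items()}, (inputs[k],))
            for k in range(K)
        ])
    loss = outputs.pow(2).mean()  # placeholder for compute_loss
    loss.backward()
    return [p.grad.clone() for p in generator.parameters()]


g_vmap, g_loop = hypernet_grads(True), hypernet_grads(False)
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_vmap, g_loop)))
```

If both paths are wired the same way, this should print True; a mismatch in a real model would point at the reshaping step or at which tensors are being inspected, rather than at vmap itself.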

Update: It seems that ws_and_bs does not hold any gradient even though it has requires_grad=True.
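This matches standard autograd behavior rather than a vmap quirk: by default .grad is only populated on leaf tensors, and ws_and_bs is computed from the hypernetwork's parameters, so it is a non-leaf. The tensor names below are hypothetical; the sketch only shows where gradients end up.

```python
import torch

hyper_weight = torch.randn(3, requires_grad=True)  # leaf, like the stacked hypernetwork params
generated = hyper_weight * 2.0                      # non-leaf, like ws_and_bs
loss = (generated ** 2).sum()
loss.backward()

print(hyper_weight.is_leaf, generated.is_leaf)  # True False
print(hyper_weight.grad is not None)             # True: gradients reached the "hypernetwork"
print(generated.grad)                            # None (PyTorch also emits a UserWarning here)
```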

Update 2: It seems I can run the forward pass through the stateless model with my generated weights, but I cannot backpropagate to them using loss.backward(). Is there any trick I can use?

zou3519 commented 1 year ago

Hey @bkoyuncu,

Do you have a longer script that we can use to reproduce the problem?

From what you are saying, it sounds like you want ws_and_bs to get gradients. You can do this by detaching them and creating new leaf tensors in the autograd graph:

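import torch.utils._pytree

def create_new_leaf(x):
    # detach x from the existing graph and make it a new leaf that accumulates .grad
    return x.detach().requires_grad_()

# tree_map returns a new structure, so the result has to be assigned back
ws_and_bs = torch.utils._pytree.tree_map(create_new_leaf, ws_and_bs)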

or by using Tensor.retain_grad.
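The two options behave differently with respect to the hypernetwork. Detaching turns ws_and_bs into new leaves, so they receive .grad, but gradients no longer flow back into the hypernetwork's parameters. retain_grad keeps the graph intact and additionally stores .grad on the intermediate tensors. A minimal sketch with a hypothetical dict of generated tensors (note that torch.utils._pytree is a private module, so its helpers may change between releases):

```python
import torch
import torch.utils._pytree as pytree


def make_generated():
    """Hypothetical hypernetwork parameter and the non-leaf tensors generated from it."""
    hyper_weight = torch.randn(4, requires_grad=True)
    ws_and_bs = {"weight": hyper_weight[:3] * 2.0, "bias": hyper_weight[3:] * 2.0}
    return hyper_weight, ws_and_bs


def loss_fn(ws_and_bs):
    return sum((t ** 2).sum() for t in ws_and_bs.values())


# Option 1: detach into new leaves. ws_and_bs gets .grad, the hypernetwork gets nothing.
hw1, wb1 = make_generated()
wb1 = pytree.tree_map(lambda x: x.detach().requires_grad_(), wb1)
loss_fn(wb1).backward()
print(wb1["weight"].grad is not None, hw1.grad)  # True None

# Option 2: retain_grad. The graph stays intact, so both receive gradients.
hw2, wb2 = make_generated()
leaves, _ = pytree.tree_flatten(wb2)
for t in leaves:
    t.retain_grad()
loss_fn(wb2).backward()
print(wb2["weight"].grad is not None, hw2.grad is not None)  # True True
```

For HyperNet training, where the goal is to update the hypernetwork, the retain_grad route (or simply reading .grad on the hypernetwork's own parameters) keeps the end-to-end gradient path.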

bkoyuncu commented 1 year ago

Thank you so much for the suggestion @zou3519, I will check this and get back to you!