pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.4k stars 102 forks source link

Per-sample gradients: all gradients are 0 when using grad(params_tograd, params) with respect to a subset of a model's parameters #1122

Open Ancientshi opened 1 year ago

Ancientshi commented 1 year ago

Hi PyTorch team, I recently needed to compute per-sample gradients with respect to a subset of a model's parameters. This works for the toy example below, but for a Wide & Deep model it returns all-zero gradients, and I don't know why.

Here is the toy example:

import torch
from functorch import grad
from functorch import make_functional_with_buffers, vmap, grad
import torch.nn.functional as F
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small CNN used as the toy model for per-sample gradients.

    Input is expected to be (N, 1, 28, 28): two valid 3x3 convolutions take
    28x28 to 24x24 and a 2x2 max-pool gives 12x12, so the flattened feature
    size is 64 * 12 * 12 = 9216, which is what ``fc1`` consumes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10).

        The log-softmax is returned rather than the raw logits because the
        loss used with this model, ``F.nll_loss``, expects log-probabilities.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Fall back to CPU so the example also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_models = 10  # NOTE(review): not used anywhere in this snippet
batch_size = 64

data = torch.randn(batch_size, 1, 28, 28, device=device)
# One target per sample; derive the count from batch_size instead of a
# hard-coded 64 so the two cannot drift apart.
targets = torch.randint(10, (batch_size,), device=device)

model = SimpleCNN().to(device=device)
model = model.eval()
# Split the module into a pure function plus its parameters and buffers.
fmodel, params, buffers = make_functional_with_buffers(model)

def loss_fn(predictions, targets):
    """Negative log-likelihood of ``predictions`` (log-probs) w.r.t. ``targets``."""
    nll = F.nll_loss(input=predictions, target=targets)
    return nll

def compute_loss_stateless_model (params_tograd,params, buffers, sample, target):
    """Loss of ``fmodel`` on a single sample, differentiable w.r.t. ``params_tograd``.

    Args:
        params_tograd: dict mapping an index into ``params`` to the tensor
            that replaces it; ``grad`` differentiates w.r.t. this argument.
        params: sequence (list or tuple) of all parameters in ``fmodel`` order.
        buffers: model buffers, passed straight through to ``fmodel``.
        sample: one input without its batch dimension.
        target: the matching target without its batch dimension.

    Returns:
        The loss computed by ``loss_fn``.
    """
    # Work on a local copy: assigning into the caller's sequence would mutate
    # it (possibly leaking transform-wrapped tensors out of grad/vmap), and
    # would fail outright on the tuple make_functional_with_buffers returns.
    params = list(params)
    for key, value in params_tograd.items():
        params[key] = value
    # fmodel expects a batch dimension, so wrap the sample in a batch of one.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    return loss_fn(predictions, targets)

# grad differentiates w.r.t. the first argument (params_tograd) only.
ft_compute_grad = grad(compute_loss_stateless_model)

# Batch over sample/target (dim 0); params_tograd, params and buffers are shared.
ft_compute_sample_grad = vmap(
    ft_compute_grad,
    in_dims=(None, None, None, 0, 0),
)

# Select the parameters to differentiate, keyed by their index in params.
params_tograd = {}
for i in (-2, 0):
    params_tograd[i] = params[i]

all_params = list(params)
ft_per_sample_grads = ft_compute_sample_grad(params_tograd, all_params, buffers, data, targets)
print(ft_per_sample_grads)

The result is:

image

However, when I apply the same method to my real use case, it doesn't work: every returned gradient is 0.

      model.load_state_dict(w_tao)
      fmodel, params, buffers = make_functional_with_buffers(model) 

      def loss_fn(predictions, targets):
          """Mean-squared-error between ``predictions`` and ``targets``."""
          mse = F.mse_loss(input=predictions, target=targets)
          return mse

      def compute_loss_stateless_model (params_tograd,params, buffers, sample, target):
          """Loss of ``fmodel`` on one sample, differentiable w.r.t. ``params_tograd``.

          Args:
              params_tograd: dict mapping an index into ``params`` to the
                  tensor that replaces it; ``grad`` differentiates w.r.t. it.
              params: sequence (list or tuple) of all parameters in
                  ``fmodel`` order.
              buffers: model buffers, passed straight through to ``fmodel``.
              sample: one input without its batch dimension.
              target: the matching target without its batch dimension.

          Returns:
              The loss computed by ``loss_fn``.
          """
          # Work on a local copy so the caller's sequence is not mutated
          # (and so a tuple of params is accepted as well as a list).
          params = list(params)
          for key, value in params_tograd.items():
              params[key] = value
          # fmodel expects a batch dimension, so wrap the sample in a batch of one.
          batch = sample.unsqueeze(0)
          targets = target.unsqueeze(0)

          predictions = fmodel(params, buffers, batch)
          return loss_fn(predictions, targets)

      ft_compute_grad = grad(compute_loss_stateless_model)
      ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None,None,None, 0, 0))
      params_tograd={}
      if dataset=='lastfm-1k':
          params_tograd[-1]=params[-1]
          params_tograd[4]=params[4]
      else:
          params_tograd[0]=params[0]

      prod_all=[]
      for batch_idx, (inputs, targets) in tqdm(enumerate(train_data_loader)):
          inputs, targets = inputs.to(self.device).float(), targets.to(self.device).float()

          ft_per_sample_grads = ft_compute_sample_grad(params_tograd,[p for p in params], buffers, inputs, targets)
          print(ft_per_sample_grads)
          sys.exit()
          if dataset!='lastfm-1k':
              params_grads=ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0] ,-1)
          elif dataset=='lastfm-1k':
              params_grad_dnn=ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0] ,-1)
              params_grad_linear=ft_per_sample_grads[4].reshape(ft_per_sample_grads[4].shape[0] ,-1)
              params_grads=torch.cat([params_grad_dnn,params_grad_linear],-1) 

          prod=torch.mm(params_grads,grad_mean.unsqueeze(1)).squeeze().detach().to('cpu').numpy()
          prod_all.extend(prod)         
      return dict(zip(range(1,len(prod_all)+1), prod_all))
image image

Also, in the Wide & Deep module, linear_logit should have shape (batch_size, 1). However, when I apply this method, an error is raised at the line below saying that the shapes of linear_logit and sparse_feat_logit do not match; I have attached the printed output. (I suspect that with this method X.shape[0] is 0, but why?)

        linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
        if len(sparse_embedding_list) > 0:
            #torch.Size([1000, 1, 7])
            sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
            if sparse_feat_refine_weight is not None:
                # w_{x,i}=m_{x,i} * w_i (in IFM and DIFM)
                sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)

            sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
            try:     
                linear_logit += sparse_feat_logit
            except:
                print(linear_logit.shape)
                print(sparse_feat_logit.shape)
                print('linear_logit\n',linear_logit)
                print('sparse_feat_logit\n',sparse_feat_logit)
                sys.exit()
                linear_logit=sparse_feat_logit
image
Ancientshi commented 1 year ago

When I try to compute the gradients of the embedding layer, I get:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.