Hi PyTorch team, I recently needed to compute per-sample gradients with respect to a subset of a model's parameters. My approach works for a toy example, but for a Wide & Deep model it returns all-zero gradients, and I don't know why.
Here is the toy example:
import torch
from functorch import grad
from functorch import make_functional_with_buffers, vmap, grad
import torch.nn.functional as F
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small two-conv / two-linear classifier for 1x28x28 inputs.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), which is
    the input format ``F.nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # 9216 = 64 channels * 12 * 12: two unpadded 3x3 convs take 28 -> 26 -> 24,
        # then the 2x2 max-pool halves it to 12.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Return log-probabilities, not raw logits: F.nll_loss assumes its input
        # is already log-softmaxed.
        return F.log_softmax(x, dim=1)
# Fall back to CPU so the example also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_models = 10  # NOTE(review): not used anywhere in the visible code
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
# Size the targets from batch_size so they cannot drift out of sync with `data`.
targets = torch.randint(10, (batch_size,), device=device)
model = SimpleCNN().to(device=device)
model = model.eval()
# Split the module into a stateless function plus its parameters and buffers,
# so functorch's grad/vmap can treat the parameters as ordinary inputs.
fmodel, params, buffers = make_functional_with_buffers(model)
def loss_fn(predictions, targets):
    """Return the mean negative log-likelihood of ``targets`` under ``predictions``.

    ``predictions`` is forwarded unchanged to ``F.nll_loss``, which interprets
    it as (N, C) log-probabilities; ``targets`` holds N class indices.
    """
    nll = F.nll_loss(predictions, targets)
    return nll
def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    """Compute the loss of ``fmodel`` on one (sample, target) pair.

    Args:
        params_tograd: dict mapping an index into ``params`` to the tensor to use
            at that position. ``grad`` differentiates with respect to this
            argument only.
        params: sequence of all model parameters; it is not modified.
        buffers: model buffers, forwarded unchanged to ``fmodel``.
        sample: a single input. Under ``vmap(..., in_dims=(None, None, None, 0, 0))``
            its leading batch dimension has already been mapped away.
        target: the matching single target.

    Returns:
        The scalar loss produced by ``loss_fn``.
    """
    # Build a fresh list rather than assigning into ``params``. That keeps the
    # caller's sequence untouched (so grad/vmap-wrapped tensors cannot leak out)
    # and also accepts tuples.
    merged = list(params)
    for index, tensor in params_tograd.items():
        merged[index] = tensor
    # Restore a batch dimension of size 1 so the model sees a batch of one.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(merged, buffers, batch)
    return loss_fn(predictions, targets)
# grad differentiates w.r.t. the first argument (params_tograd); vmap then maps
# it over dim 0 of `data` and `targets`, giving one gradient per sample.
ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))

# Positions in `params` to differentiate. For SimpleCNN (parameters follow
# registration order, weight before bias) these are fc2.weight and conv1.weight.
params_tograd = {idx: params[idx] for idx in (-2, 0)}

ft_per_sample_grads = ft_compute_sample_grad(
    params_tograd, list(params), buffers, data, targets
)
print(ft_per_sample_grads)
The result is:
However, when I apply this method to the real scenario, it doesn't work: every gradient it returns is zero.
# Load the trained weights into `model`, then rebuild the stateless view of it.
# This rebinds `fmodel`, `params` and `buffers`, the names that
# compute_loss_stateless_model and the gradient loop below read.
# NOTE(review): functorch's make_functional_with_buffers is deprecated in
# PyTorch 2.x in favour of torch.func.functional_call; consider migrating.
model.load_state_dict(w_tao)
fmodel, params, buffers = make_functional_with_buffers(model)
def loss_fn(predictions, targets):
    """Return the mean squared error between ``predictions`` and ``targets``."""
    squared_error = F.mse_loss(predictions, targets)
    return squared_error
def compute_loss_stateless_model(params_tograd, params, buffers, sample, target):
    """Compute the loss of ``fmodel`` on one (sample, target) pair.

    Args:
        params_tograd: dict mapping an index into ``params`` to the tensor to use
            at that position. ``grad`` differentiates with respect to this
            argument only.
        params: sequence of all model parameters; it is not modified.
        buffers: model buffers, forwarded unchanged to ``fmodel``.
        sample: a single input. Under ``vmap(..., in_dims=(None, None, None, 0, 0))``
            its leading batch dimension has already been mapped away.
        target: the matching single target.

    Returns:
        The scalar loss produced by ``loss_fn``.
    """
    # Build a fresh list rather than assigning into ``params``. That keeps the
    # caller's sequence untouched (so grad/vmap-wrapped tensors cannot leak out)
    # and also accepts tuples.
    merged = list(params)
    for index, tensor in params_tograd.items():
        merged[index] = tensor
    # Restore a batch dimension of size 1 so the model sees a batch of one.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(merged, buffers, batch)
    return loss_fn(predictions, targets)
ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None,None,None, 0, 0))
params_tograd={}
if dataset=='lastfm-1k':
params_tograd[-1]=params[-1]
params_tograd[4]=params[4]
else:
params_tograd[0]=params[0]
prod_all=[]
for batch_idx, (inputs, targets) in tqdm(enumerate(train_data_loader)):
inputs, targets = inputs.to(self.device).float(), targets.to(self.device).float()
ft_per_sample_grads = ft_compute_sample_grad(params_tograd,[p for p in params], buffers, inputs, targets)
print(ft_per_sample_grads)
sys.exit()
if dataset!='lastfm-1k':
params_grads=ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0] ,-1)
elif dataset=='lastfm-1k':
params_grad_dnn=ft_per_sample_grads[-1].reshape(ft_per_sample_grads[-1].shape[0] ,-1)
params_grad_linear=ft_per_sample_grads[4].reshape(ft_per_sample_grads[4].shape[0] ,-1)
params_grads=torch.cat([params_grad_dnn,params_grad_linear],-1)
prod=torch.mm(params_grads,grad_mean.unsqueeze(1)).squeeze().detach().to('cpu').numpy()
prod_all.extend(prod)
return dict(zip(range(1,len(prod_all)+1), prod_all))
Also, in the Wide & Deep module, linear_logit should have shape (batch_size, 1). When I apply this method, however, an error occurs in the code below: the shapes of linear_logit and sparse_feat_logit do not match. I have attached the printed output. (I suspect X.shape[0] is 0 under this method, but why?)
# Allocate directly on the target device instead of creating on CPU and copying.
linear_logit = torch.zeros(X.shape[0], 1, device=self.device)
if len(sparse_embedding_list) > 0:
    # torch.Size([1000, 1, 7]) -- shape recorded while debugging; presumably of
    # sparse_embedding_cat (TODO confirm).
    sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
    if sparse_feat_refine_weight is not None:
        # w_{x,i}=m_{x,i} * w_i (in IFM and DIFM)
        sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)
    sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
    # Out-of-place add. Under functorch's vmap, `linear_logit` (a fresh
    # torch.zeros) is an ordinary tensor while `sparse_feat_logit` is batched.
    # An in-place `+=` into the unbatched tensor is rejected; `+` broadcasts
    # correctly. Also under vmap, X.shape[0] is the per-sample size (1 when each
    # sample is unsqueeze(0)-ed before the forward), not the full batch size.
    linear_logit = linear_logit + sparse_feat_logit
Hi PyTorch team, I recently needed to compute per-sample gradients with respect to a subset of a model's parameters. My approach works for a toy example, but for a Wide & Deep model it returns all-zero gradients, and I don't know why.
Here is the toy example:
The result is:
However, when I apply this method to the real scenario, it doesn't work: every gradient it returns is zero.
Also, in the Wide & Deep module, linear_logit should have shape (batch_size, 1). When I apply this method, however, an error occurs in the code below: the shapes of linear_logit and sparse_feat_logit do not match. I have attached the printed output. (I suspect X.shape[0] is 0 under this method, but why?)