locuslab / torchdeq

Modern Fixed Point Systems using PyTorch

Returned fixed point solution sometimes has requires_grad=False #8

Open Boshi-Wang opened 4 months ago

Boshi-Wang commented 4 months ago

Thanks for the great library!

I'm learning to use the library by writing a DEQ layer inside GPT-2. After loss.backward(), some model parameters do not get gradients. I tracked the issue down, and it turns out that the fixed point solution z_out[-1] returned by the DEQ layer (a tensor of shape (batch size, sequence length, hidden dimension)) has requires_grad set to False. Strangely, if I use a freshly initialized GPT-2 model instead (without the pretrained weights), the issue is gone.
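For reference, the way I check which parameters are missing gradients is roughly the following (just a sketch; model, the loss, and the backward call come from my training script and are not shown here):

# rough check after loss.backward(): list parameters whose .grad is still None
missing = [name for name, p in model.named_parameters()
           if p.requires_grad and p.grad is None]
print(len(missing), "parameters did not receive a gradient, e.g.", missing[:5])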

Specifically, this is my deq setup:

self.deq = get_deq(
    ift=True,
    f_solver='broyden', f_max_iter=15, f_tol=1e-3, f_stop_mode='rel',
    b_solver='broyden', b_max_iter=15, b_tol=1e-3, b_stop_mode='rel',
    )

If I print requires_grad of the solution in each iteration, before and after the implicit function, like this:

def layer_iter(hidden_states, input_hidden_states):
    print("before:", hidden_states.requires_grad)
    # forward pass through each transformer block, where input_hidden_states is the input embedding
    # ...
    print("after:", hidden_states.requires_grad)
    print("----")
    return hidden_states

func = lambda var: layer_iter(var, hidden_states_)    # hidden_states_ is the input after the embedding layer
zeros_ = torch.zeros(*hidden_states_.shape, requires_grad=True).to(hidden_states_.device)
z_out, info = self.deq(func, zeros_)
print("z_out[-1].requires_grad:", z_out[-1].requires_grad)
print(info)
hidden_states = z_out[-1]
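(Presumably one could also inspect grad_fn here, since requires_grad alone does not say whether the tensor is actually connected to the model parameters; I did not include that in the runs below, but the check would just be:)

# grad_fn is None when a tensor is detached from the autograd graph;
# if that is the case for z_out[-1], no gradient can reach the parameters inside layer_iter
print("z_out[-1].grad_fn:", z_out[-1].grad_fn)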

The main script is:

import torch
from torchdeq import get_deq
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)
model.to(device)

batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
inputs = {key: value.to(device) for key, value in inputs.items()}

outputs = model(**inputs, use_cache=False)

Then I get:

before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
z_out[-1].requires_grad: False
{'abs_lowest': tensor([914.4200, 914.4200], device='cuda:0'), 'rel_lowest': tensor([6.5291e-05, 6.5291e-05], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
         1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
         9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02],
        [1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
         1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
         9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
         4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
         6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05],
        [1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
         4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
         6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05]], device='cuda:0'), 'nstep': tensor([9., 9.], device='cuda:0'), 'sradius': tensor([0.])}

So it seems the model does converge in 9 iterations, but the returned hidden states have requires_grad=False, and the model's parameters (other than those after the DEQ layer) do not get gradients. I tried manually setting z_out[-1].requires_grad = True, but that doesn't help; after loss.backward(), .grad is still None for those parameters.
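(That the manual flip does not help is consistent with how autograd works in general: once a tensor has been cut off from the graph, turning requires_grad back on only makes it a fresh leaf and does not reconnect it to the parameters. A minimal PyTorch-only illustration, unrelated to torchdeq:)

import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2).detach()     # y is cut off from the graph, like z_out[-1] appears to be
y.requires_grad = True   # this only makes y a new leaf tensor
y.sum().backward()
print(w.grad)            # None -- the gradient never reaches w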

Intriguingly, if I use a freshly initialized GPT-2, the issue seems to go away:

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)

# re-initialize the weights randomly (discarding the pretrained weights)
config = model.config
model = GPT2Model(config)

batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = {key: value.to(device) for key, value in inputs.items()}
model.to(device)

outputs = model(**inputs, use_cache=False)

and I get:

before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: True
----
z_out[-1].requires_grad: True
{'abs_lowest': tensor([9.0937, 9.0937], device='cuda:0'), 'rel_lowest': tensor([0.0002, 0.0002], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
         9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
         9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00],
        [1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
         9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
         9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04],
        [1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04]], device='cuda:0'), 'nstep': tensor([6., 6.], device='cuda:0'), 'sradius': tensor([-1.])}
max/min/mean steps: 6.0, 6.0, 6.0

Here requires_grad becomes True, both on the last iteration and on z_out[-1]. Also, apart from requires_grad, the sradius in info is now -1 instead of 0; I'm not sure whether that is related.

I wonder if you have any ideas about why this happens. I'd appreciate it!